Skip to content

Commit

Permalink
Provides ExitCode Shutdowner Option; and Wait method to receive it. (#…
Browse files Browse the repository at this point in the history
…989)

This PR provides an option for those who take dependencies on the
`Shutdowner` interface to call the `Shutdown` method with an `ExitCode`
option, in addition it add a `Wait` method to the application to allow
for main programs to wait for the application to be shutdown and to exit
with a given exit code.

Please note that this PR refactors the existing signal relay
functionality, and alters application lifecycle slightly. Now `Done`
will not receive an `os.Signal` on the channel it returns _unless_ a
given FX application has been started.

Co-authored-by: Sung Yoon Whang <[email protected]>
  • Loading branch information
jasonmills and sywhang authored Dec 7, 2022
1 parent 94f1a09 commit b379e13
Show file tree
Hide file tree
Showing 7 changed files with 377 additions and 39 deletions.
16 changes: 12 additions & 4 deletions annotated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,8 @@ func assertApp(
invoked *bool,
) {
t.Helper()
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
assert.False(t, *started)
require.NoError(t, app.Start(ctx))
assert.True(t, *started)
Expand Down Expand Up @@ -1517,8 +1518,11 @@ func TestHookAnnotations(t *testing.T) {
t.Run("with hook on invoke", func(t *testing.T) {
t.Parallel()

var started bool
var invoked bool
var (
started bool
stopped bool
invoked bool
)
hook := fx.Annotate(
func() {
invoked = true
Expand All @@ -1527,10 +1531,14 @@ func TestHookAnnotations(t *testing.T) {
started = true
return nil
}),
fx.OnStop(func(context.Context) error {
stopped = true
return nil
}),
)
app := fxtest.New(t, fx.Invoke(hook))

assertApp(t, app, &started, nil, &invoked)
assertApp(t, app, &started, &stopped, &invoked)
})

t.Run("depend on result interface of target", func(t *testing.T) {
Expand Down
43 changes: 39 additions & 4 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,9 +619,13 @@ func (app *App) Start(ctx context.Context) (err error) {
})
}

func (app *App) start(ctx context.Context) error {
if err := app.lifecycle.Start(ctx); err != nil {
// Start failed, rolling back.
// withRollback will execute an anonymous function with a given context.
// if the anon func returns an error, rollback methods will be called and related events emitted
func (app *App) withRollback(
ctx context.Context,
f func(context.Context) error,
) error {
if err := f(ctx); err != nil {
app.log().LogEvent(&fxevent.RollingBack{StartErr: err})

stopErr := app.lifecycle.Stop(ctx)
Expand All @@ -633,9 +637,20 @@ func (app *App) start(ctx context.Context) error {

return err
}

return nil
}

func (app *App) start(ctx context.Context) error {
return app.withRollback(ctx, func(ctx context.Context) error {
if err := app.lifecycle.Start(ctx); err != nil {
return err
}
app.receivers.Start(ctx)
return nil
})
}

// Stop gracefully stops the application. It executes any registered OnStop
// hooks in reverse order, so that each constructor's stop hooks are called
// before its dependencies' stop hooks.
Expand All @@ -648,9 +663,14 @@ func (app *App) Stop(ctx context.Context) (err error) {
app.log().LogEvent(&fxevent.Stopped{Err: err})
}()

cb := func(ctx context.Context) error {
defer app.receivers.Stop(ctx)
return app.lifecycle.Stop(ctx)
}

return withTimeout(ctx, &withTimeoutParams{
hook: _onStopHook,
callback: app.lifecycle.Stop,
callback: cb,
lifecycle: app.lifecycle,
log: app.log(),
})
Expand All @@ -663,10 +683,25 @@ func (app *App) Stop(ctx context.Context) (err error) {
//
// Alternatively, a signal can be broadcast to all done channels manually by
// using the Shutdown functionality (see the Shutdowner documentation for details).
//
// Note: The channel Done returns will not receive a signal unless the application
// as been started via Start or Run.
func (app *App) Done() <-chan os.Signal {
return app.receivers.Done()
}

// Wait returns a channel of [ShutdownSignal] to block on after starting the
// application and function, similar to [App.Done], but with a minor difference.
// Should an ExitCode be provided as a [ShutdownOption] to
// the Shutdowner Shutdown method, the exit code will be available as part
// of the ShutdownSignal struct.
//
// Should the app receive a SIGTERM or SIGINT, the given
// signal will be populated in the ShutdownSignal struct.
func (app *App) Wait() <-chan ShutdownSignal {
return app.receivers.Wait()
}

// StartTimeout returns the configured startup timeout. Apps default to using
// DefaultTimeout, but users can configure this behavior using the
// StartTimeout option.
Expand Down
11 changes: 7 additions & 4 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,9 @@ func TestAppStart(t *testing.T) {
t.Run("StartTwiceWithHooksErrors", func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

app := fxtest.New(t,
Invoke(func(lc Lifecycle) {
lc.Append(Hook{
Expand All @@ -1293,13 +1296,13 @@ func TestAppStart(t *testing.T) {
})
}),
)
assert.NoError(t, app.Start(context.Background()))
err := app.Start(context.Background())
assert.NoError(t, app.Start(ctx))
err := app.Start(ctx)
if assert.Error(t, err) {
assert.ErrorContains(t, err, "attempted to start lifecycle when in state: started")
}
app.Stop(context.Background())
assert.NoError(t, app.Start(context.Background()))
app.Stop(ctx)
assert.NoError(t, app.Start(ctx))
})
}

Expand Down
63 changes: 61 additions & 2 deletions shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@

package fx

import (
"context"
"time"
)

// Shutdowner provides a method that can manually trigger the shutdown of the
// application by sending a signal to all open Done channels. Shutdowner works
// on applications using Run as well as Start, Done, and Stop. The Shutdowner is
Expand All @@ -34,8 +39,42 @@ type ShutdownOption interface {
apply(*shutdowner)
}

type exitCodeOption int

func (code exitCodeOption) apply(s *shutdowner) {
s.exitCode = int(code)
}

var _ ShutdownOption = exitCodeOption(0)

// ExitCode is a [ShutdownOption] that may be passed to the Shutdown method of the
// [Shutdowner] interface.
// The given integer exit code will be broadcasted to any receiver waiting
// on a [ShutdownSignal] from the [Wait] method.
func ExitCode(code int) ShutdownOption {
return exitCodeOption(code)
}

type shutdownTimeoutOption time.Duration

func (to shutdownTimeoutOption) apply(s *shutdowner) {
s.shutdownTimeout = time.Duration(to)
}

var _ ShutdownOption = shutdownTimeoutOption(0)

// ShutdownTimeout is a [ShutdownOption] that allows users to specify a timeout
// for a given call to Shutdown method of the [Shutdowner] interface. As the
// Shutdown method will block while waiting for a signal receiver relay
// goroutine to stop.
func ShutdownTimeout(timeout time.Duration) ShutdownOption {
return shutdownTimeoutOption(timeout)
}

type shutdowner struct {
app *App
app *App
exitCode int
shutdownTimeout time.Duration
}

// Shutdown broadcasts a signal to all of the application's Done channels
Expand All @@ -44,7 +83,27 @@ type shutdowner struct {
// In practice this means Shutdowner.Shutdown should not be called from an
// fx.Invoke, but from a fx.Lifecycle.OnStart hook.
func (s *shutdowner) Shutdown(opts ...ShutdownOption) error {
return s.app.receivers.Broadcast(ShutdownSignal{Signal: _sigTERM})
for _, opt := range opts {
opt.apply(s)
}

ctx := context.Background()

if s.shutdownTimeout != time.Duration(0) {
c, cancel := context.WithTimeout(
context.Background(),
s.shutdownTimeout,
)
defer cancel()
ctx = c
}

defer s.app.receivers.Stop(ctx)

return s.app.receivers.Broadcast(ShutdownSignal{
Signal: _sigTERM,
ExitCode: s.exitCode,
})
}

func (app *App) shutdowner() Shutdowner {
Expand Down
43 changes: 42 additions & 1 deletion shutdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ package fx_test

import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -61,12 +63,14 @@ func TestShutdown(t *testing.T) {
)

done := app.Done()
wait := app.Wait()
defer app.RequireStart().RequireStop()
assert.NoError(t, s.Shutdown(), "error returned from first shutdown call")

assert.EqualError(t, s.Shutdown(), "send terminated signal: 1/1 channels are blocked",
assert.EqualError(t, s.Shutdown(), "send terminated signal: 2/2 channels are blocked",
"unexpected error returned when shutdown is called with a blocked channel")
assert.NotNil(t, <-done, "done channel did not receive signal")
assert.NotNil(t, <-wait, "wait channel did not receive signal")
})

t.Run("shutdown app before calling Done()", func(t *testing.T) {
Expand All @@ -87,6 +91,43 @@ func TestShutdown(t *testing.T) {
assert.NotNil(t, <-done1, "done channel 1 did not receive signal")
assert.NotNil(t, <-done2, "done channel 2 did not receive signal")
})

t.Run("with exit code", func(t *testing.T) {
t.Parallel()
var s fx.Shutdowner
app := fxtest.New(
t,
fx.Populate(&s),
)

require.NoError(t, app.Start(context.Background()), "error starting app")
assert.NoError(t, s.Shutdown(fx.ExitCode(2)), "error in app shutdown")
wait := <-app.Wait()
defer app.Stop(context.Background())
require.Equal(t, 2, wait.ExitCode)
})

t.Run("with exit code and multiple Wait", func(t *testing.T) {
t.Parallel()
var s fx.Shutdowner
app := fxtest.New(
t,
fx.Populate(&s),
)

require.NoError(t, app.Start(context.Background()), "error starting app")
defer require.NoError(t, app.Stop(context.Background()))

for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("Wait %v", i), func(t *testing.T) {
t.Parallel()
wait := <-app.Wait()
require.Equal(t, 2, wait.ExitCode)
})
}

assert.NoError(t, s.Shutdown(fx.ExitCode(2), fx.ShutdownTimeout(time.Second)))
})
}

func TestDataRace(t *testing.T) {
Expand Down
Loading

0 comments on commit b379e13

Please sign in to comment.