From b379e132182b47a3f2400212e136898575665154 Mon Sep 17 00:00:00 2001 From: jasonmills Date: Wed, 7 Dec 2022 14:09:30 -0800 Subject: [PATCH] Provides ExitCode Shutdowner Option; and Wait method to receive it. (#989) This PR provides an option for those who take dependencies on the `Shutdowner` interface to call the `Shutdown` method with an `ExitCode` option. In addition, it adds a `Wait` method to the application to allow main programs to wait for the application to be shut down and to exit with a given exit code. Please note that this PR refactors the existing signal relay functionality, and alters application lifecycle slightly. Now `Done` will not receive an `os.Signal` on the channel it returns _unless_ a given FX application has been started. Co-authored-by: Sung Yoon Whang --- annotated_test.go | 16 +++-- app.go | 43 +++++++++++-- app_test.go | 11 ++-- shutdown.go | 63 +++++++++++++++++- shutdown_test.go | 43 ++++++++++++- signal.go | 158 +++++++++++++++++++++++++++++++++++++++++++--- signal_test.go | 82 +++++++++++++++++++----- 7 files changed, 377 insertions(+), 39 deletions(-) diff --git a/annotated_test.go b/annotated_test.go index a45fd8b94..7eefddde4 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -1485,7 +1485,8 @@ func assertApp( invoked *bool, ) { t.Helper() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() assert.False(t, *started) require.NoError(t, app.Start(ctx)) assert.True(t, *started) @@ -1517,8 +1518,11 @@ func TestHookAnnotations(t *testing.T) { t.Run("with hook on invoke", func(t *testing.T) { t.Parallel() - var started bool - var invoked bool + var ( + started bool + stopped bool + invoked bool + ) hook := fx.Annotate( func() { invoked = true @@ -1527,10 +1531,14 @@ func TestHookAnnotations(t *testing.T) { started = true return nil }), + fx.OnStop(func(context.Context) error { + stopped = true + return nil + }), ) app := fxtest.New(t, fx.Invoke(hook)) - assertApp(t, app, 
&started, nil, &invoked) + assertApp(t, app, &started, &stopped, &invoked) }) t.Run("depend on result interface of target", func(t *testing.T) { diff --git a/app.go b/app.go index ad2b08ce7..c8a728cea 100644 --- a/app.go +++ b/app.go @@ -619,9 +619,13 @@ func (app *App) Start(ctx context.Context) (err error) { }) } -func (app *App) start(ctx context.Context) error { - if err := app.lifecycle.Start(ctx); err != nil { - // Start failed, rolling back. +// withRollback will execute an anonymous function with a given context. +// if the anon func returns an error, rollback methods will be called and related events emitted +func (app *App) withRollback( + ctx context.Context, + f func(context.Context) error, +) error { + if err := f(ctx); err != nil { app.log().LogEvent(&fxevent.RollingBack{StartErr: err}) stopErr := app.lifecycle.Stop(ctx) @@ -633,9 +637,20 @@ func (app *App) start(ctx context.Context) error { return err } + return nil } +func (app *App) start(ctx context.Context) error { + return app.withRollback(ctx, func(ctx context.Context) error { + if err := app.lifecycle.Start(ctx); err != nil { + return err + } + app.receivers.Start(ctx) + return nil + }) +} + // Stop gracefully stops the application. It executes any registered OnStop // hooks in reverse order, so that each constructor's stop hooks are called // before its dependencies' stop hooks. 
@@ -648,9 +663,14 @@ func (app *App) Stop(ctx context.Context) (err error) { app.log().LogEvent(&fxevent.Stopped{Err: err}) }() + cb := func(ctx context.Context) error { + defer app.receivers.Stop(ctx) + return app.lifecycle.Stop(ctx) + } + return withTimeout(ctx, &withTimeoutParams{ hook: _onStopHook, - callback: app.lifecycle.Stop, + callback: cb, lifecycle: app.lifecycle, log: app.log(), }) @@ -663,10 +683,25 @@ func (app *App) Stop(ctx context.Context) (err error) { // // Alternatively, a signal can be broadcast to all done channels manually by // using the Shutdown functionality (see the Shutdowner documentation for details). +// +// Note: The channel Done returns will not receive a signal unless the application +// has been started via Start or Run. func (app *App) Done() <-chan os.Signal { return app.receivers.Done() } +// Wait returns a channel of [ShutdownSignal] to block on after starting the +// application and functions similarly to [App.Done], but with a minor difference. +// Should an ExitCode be provided as a [ShutdownOption] to +// the Shutdowner Shutdown method, the exit code will be available as part +// of the ShutdownSignal struct. +// +// Should the app receive a SIGTERM or SIGINT, the given +// signal will be populated in the ShutdownSignal struct. +func (app *App) Wait() <-chan ShutdownSignal { + return app.receivers.Wait() +} + // StartTimeout returns the configured startup timeout. Apps default to using // DefaultTimeout, but users can configure this behavior using the // StartTimeout option. 
diff --git a/app_test.go b/app_test.go index 7beae0157..667294c79 100644 --- a/app_test.go +++ b/app_test.go @@ -1285,6 +1285,9 @@ func TestAppStart(t *testing.T) { t.Run("StartTwiceWithHooksErrors", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + app := fxtest.New(t, Invoke(func(lc Lifecycle) { lc.Append(Hook{ @@ -1293,13 +1296,13 @@ func TestAppStart(t *testing.T) { }) }), ) - assert.NoError(t, app.Start(context.Background())) - err := app.Start(context.Background()) + assert.NoError(t, app.Start(ctx)) + err := app.Start(ctx) if assert.Error(t, err) { assert.ErrorContains(t, err, "attempted to start lifecycle when in state: started") } - app.Stop(context.Background()) - assert.NoError(t, app.Start(context.Background())) + app.Stop(ctx) + assert.NoError(t, app.Start(ctx)) }) } diff --git a/shutdown.go b/shutdown.go index eebb5f1b5..aa81e68d3 100644 --- a/shutdown.go +++ b/shutdown.go @@ -20,6 +20,11 @@ package fx +import ( + "context" + "time" +) + // Shutdowner provides a method that can manually trigger the shutdown of the // application by sending a signal to all open Done channels. Shutdowner works // on applications using Run as well as Start, Done, and Stop. The Shutdowner is @@ -34,8 +39,42 @@ type ShutdownOption interface { apply(*shutdowner) } +type exitCodeOption int + +func (code exitCodeOption) apply(s *shutdowner) { + s.exitCode = int(code) +} + +var _ ShutdownOption = exitCodeOption(0) + +// ExitCode is a [ShutdownOption] that may be passed to the Shutdown method of the +// [Shutdowner] interface. +// The given integer exit code will be broadcasted to any receiver waiting +// on a [ShutdownSignal] from the [Wait] method. 
+func ExitCode(code int) ShutdownOption { + return exitCodeOption(code) +} + +type shutdownTimeoutOption time.Duration + +func (to shutdownTimeoutOption) apply(s *shutdowner) { + s.shutdownTimeout = time.Duration(to) +} + +var _ ShutdownOption = shutdownTimeoutOption(0) + +// ShutdownTimeout is a [ShutdownOption] that allows users to specify a timeout +// for a given call to the Shutdown method of the [Shutdowner] interface, as the +// Shutdown method will block while waiting for a signal receiver relay +// goroutine to stop. +func ShutdownTimeout(timeout time.Duration) ShutdownOption { + return shutdownTimeoutOption(timeout) +} + type shutdowner struct { - app *App + app *App + exitCode int + shutdownTimeout time.Duration } // Shutdown broadcasts a signal to all of the application's Done channels @@ -44,7 +83,27 @@ type shutdowner struct { // In practice this means Shutdowner.Shutdown should not be called from an // fx.Invoke, but from a fx.Lifecycle.OnStart hook. func (s *shutdowner) Shutdown(opts ...ShutdownOption) error { - return s.app.receivers.Broadcast(ShutdownSignal{Signal: _sigTERM}) + for _, opt := range opts { + opt.apply(s) + } + + ctx := context.Background() + + if s.shutdownTimeout != time.Duration(0) { + c, cancel := context.WithTimeout( + context.Background(), + s.shutdownTimeout, + ) + defer cancel() + ctx = c + } + + defer s.app.receivers.Stop(ctx) + + return s.app.receivers.Broadcast(ShutdownSignal{ + Signal: _sigTERM, + ExitCode: s.exitCode, + }) } func (app *App) shutdowner() Shutdowner { diff --git a/shutdown_test.go b/shutdown_test.go index a6d0ad508..1f322a2cb 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -22,8 +22,10 @@ package fx_test import ( "context" + "fmt" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -61,12 +63,14 @@ func TestShutdown(t *testing.T) { ) done := app.Done() + wait := app.Wait() defer app.RequireStart().RequireStop() assert.NoError(t, s.Shutdown(), "error 
returned from first shutdown call") - assert.EqualError(t, s.Shutdown(), "send terminated signal: 1/1 channels are blocked", + assert.EqualError(t, s.Shutdown(), "send terminated signal: 2/2 channels are blocked", "unexpected error returned when shutdown is called with a blocked channel") assert.NotNil(t, <-done, "done channel did not receive signal") + assert.NotNil(t, <-wait, "wait channel did not receive signal") }) t.Run("shutdown app before calling Done()", func(t *testing.T) { @@ -87,6 +91,43 @@ func TestShutdown(t *testing.T) { assert.NotNil(t, <-done1, "done channel 1 did not receive signal") assert.NotNil(t, <-done2, "done channel 2 did not receive signal") }) + + t.Run("with exit code", func(t *testing.T) { + t.Parallel() + var s fx.Shutdowner + app := fxtest.New( + t, + fx.Populate(&s), + ) + + require.NoError(t, app.Start(context.Background()), "error starting app") + assert.NoError(t, s.Shutdown(fx.ExitCode(2)), "error in app shutdown") + wait := <-app.Wait() + defer app.Stop(context.Background()) + require.Equal(t, 2, wait.ExitCode) + }) + + t.Run("with exit code and multiple Wait", func(t *testing.T) { + t.Parallel() + var s fx.Shutdowner + app := fxtest.New( + t, + fx.Populate(&s), + ) + + require.NoError(t, app.Start(context.Background()), "error starting app") + defer require.NoError(t, app.Stop(context.Background())) + + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("Wait %v", i), func(t *testing.T) { + t.Parallel() + wait := <-app.Wait() + require.Equal(t, 2, wait.ExitCode) + }) + } + + assert.NoError(t, s.Shutdown(fx.ExitCode(2), fx.ShutdownTimeout(time.Second))) + }) } func TestDataRace(t *testing.T) { diff --git a/signal.go b/signal.go index 4c0b28763..79dcfeb41 100644 --- a/signal.go +++ b/signal.go @@ -21,15 +21,22 @@ package fx import ( + "context" "fmt" "os" "os/signal" "sync" ) -// ShutdownSignal is a signal that caused the application to exit. +// ShutdownSignal represents a signal to be written to Wait or Done. 
+// Should a user call the Shutdown method via the Shutdowner interface with +// a provided ExitCode, that exit code will be populated in the ExitCode field. +// +// Should the application receive an operating system signal, +// the Signal field will be populated with the received os.Signal. type ShutdownSignal struct { - Signal os.Signal + Signal os.Signal + ExitCode int } // String will render a ShutdownSignal type as a string suitable for printing. @@ -38,17 +45,103 @@ func (sig ShutdownSignal) String() string { } func newSignalReceivers() signalReceivers { - return signalReceivers{notify: signal.Notify} + return signalReceivers{ + notify: signal.Notify, + signals: make(chan os.Signal, 1), + } } type signalReceivers struct { - m sync.Mutex - last *ShutdownSignal - done []chan os.Signal + // this mutex protects writes and reads of this struct to prevent + // race conditions in a parallel execution pattern + m sync.Mutex + + // our os.Signal channel we relay from + signals chan os.Signal + // when written to, will instruct the signal relayer to shutdown + shutdown chan struct{} + // is written to when signal relay has finished shutting down + finished chan struct{} + + // this stub allows us to unit test signal relay functionality notify func(c chan<- os.Signal, sig ...os.Signal) + + // last will contain a pointer to the last ShutdownSignal received, or + // nil if none, if a new channel is created by Wait or Done, this last + // signal will be immediately written to, this allows Wait or Done state + // to be read after application stop + last *ShutdownSignal + + // contains channels created by Done + done []chan os.Signal + + // contains channels created by Wait + wait []chan ShutdownSignal +} + +func (recv *signalReceivers) relayer(ctx context.Context) { + defer func() { + recv.finished <- struct{}{} + }() + + select { + case <-recv.shutdown: + return + case <-ctx.Done(): + return + case signal := <-recv.signals: + recv.Broadcast(ShutdownSignal{ + Signal: 
signal, + }) + } +} + +// running returns true if the signal relay go-routine is running. +// this method must be invoked under a locked mutex to avoid a race condition. +func (recv *signalReceivers) running() bool { + return recv.shutdown != nil && recv.finished != nil } -func (recv *signalReceivers) Done() <-chan os.Signal { +func (recv *signalReceivers) Start(ctx context.Context) { + recv.m.Lock() + defer recv.m.Unlock() + + // if the receiver has already been started; don't start it again + if recv.running() { + return + } + + recv.last = nil + recv.finished = make(chan struct{}, 1) + recv.shutdown = make(chan struct{}, 1) + recv.notify(recv.signals, os.Interrupt, _sigINT, _sigTERM) + go recv.relayer(ctx) +} + +func (recv *signalReceivers) Stop(ctx context.Context) error { + recv.m.Lock() + defer recv.m.Unlock() + + // if the relayer is not running; return nil error + if !recv.running() { + return nil + } + + recv.shutdown <- struct{}{} + + select { + case <-ctx.Done(): + return ctx.Err() + case <-recv.finished: + close(recv.shutdown) + close(recv.finished) + recv.shutdown = nil + recv.finished = nil + return nil + } +} + +func (recv *signalReceivers) Done() chan os.Signal { recv.m.Lock() defer recv.m.Unlock() @@ -62,17 +155,35 @@ func (recv *signalReceivers) Done() <-chan os.Signal { ch <- recv.last.Signal } - recv.notify(ch, os.Interrupt, _sigINT, _sigTERM) recv.done = append(recv.done, ch) return ch } +func (recv *signalReceivers) Wait() chan ShutdownSignal { + recv.m.Lock() + defer recv.m.Unlock() + + ch := make(chan ShutdownSignal, 1) + + if recv.last != nil { + ch <- *recv.last + } + + recv.wait = append(recv.wait, ch) + return ch +} + func (recv *signalReceivers) Broadcast(signal ShutdownSignal) error { recv.m.Lock() defer recv.m.Unlock() + recv.last = &signal - channels, unsent := recv.broadcast( + signal, + recv.broadcastDone, + recv.broadcastWait, + ) if unsent != 0 { return &unsentSignalError{ @@ -85,6 
+196,21 @@ func (recv *signalReceivers) Broadcast(signal ShutdownSignal) error { return nil } +func (recv *signalReceivers) broadcast( + signal ShutdownSignal, + anchors ...func(ShutdownSignal) (int, int), +) (int, int) { + var channels, unsent int + + for _, anchor := range anchors { + c, u := anchor(signal) + channels += c + unsent += u + } + + return channels, unsent +} + func (recv *signalReceivers) broadcastDone(signal ShutdownSignal) (int, int) { var unsent int @@ -99,6 +225,20 @@ func (recv *signalReceivers) broadcastDone(signal ShutdownSignal) (int, int) { return len(recv.done), unsent } +func (recv *signalReceivers) broadcastWait(signal ShutdownSignal) (int, int) { + var unsent int + + for _, reader := range recv.wait { + select { + case reader <- signal: + default: + unsent++ + } + } + + return len(recv.wait), unsent +} + type unsentSignalError struct { Signal ShutdownSignal Unsent int diff --git a/signal_test.go b/signal_test.go index 481b74ec6..527213244 100644 --- a/signal_test.go +++ b/signal_test.go @@ -21,10 +21,13 @@ package fx import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "context" + "os" "syscall" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func assertUnsentSignalError( @@ -44,22 +47,71 @@ func assertUnsentSignalError( func TestSignal(t *testing.T) { t.Parallel() - recv := newSignalReceivers() - a := recv.Done() - _ = recv.Done() // we never listen on this + t.Run("Done", func(t *testing.T) { + recv := newSignalReceivers() + a := recv.Done() + _ = recv.Done() // we never listen on this - expected := ShutdownSignal{ - Signal: syscall.SIGTERM, - } + expected := ShutdownSignal{ + Signal: syscall.SIGTERM, + } - require.NoError(t, recv.Broadcast(expected), "first broadcast should succeed") + require.NoError(t, recv.Broadcast(expected), "first broadcast should succeed") - assertUnsentSignalError(t, recv.Broadcast(expected), &unsentSignalError{ - Signal: 
expected, - Total: 2, - Unsent: 2, + assertUnsentSignalError(t, recv.Broadcast(expected), &unsentSignalError{ + Signal: expected, + Total: 2, + Unsent: 2, + }) + + assert.Equal(t, expected.Signal, <-a) + assert.Equal(t, expected.Signal, <-recv.Done(), "expect cached signal") }) - assert.Equal(t, expected.Signal, <-a) - assert.Equal(t, expected.Signal, <-recv.Done(), "expect cached signal") + t.Run("signal notify relayer", func(t *testing.T) { + t.Parallel() + t.Run("start and stop", func(t *testing.T) { + t.Parallel() + t.Run("timeout", func(t *testing.T) { + recv := newSignalReceivers() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + recv.Start(ctx) + timeoutCtx, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + err := recv.Stop(timeoutCtx) + require.ErrorIs(t, err, context.DeadlineExceeded) + }) + t.Run("no error", func(t *testing.T) { + recv := newSignalReceivers() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + recv.Start(ctx) + recv.Start(ctx) // should be a no-op if already running + require.NoError(t, recv.Stop(ctx)) + }) + t.Run("notify", func(t *testing.T) { + stub := make(chan os.Signal) + recv := newSignalReceivers() + recv.notify = func(ch chan<- os.Signal, _ ...os.Signal) { + go func() { + for sig := range stub { + ch <- sig + } + }() + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + recv.Start(ctx) + stub <- syscall.SIGTERM + stub <- syscall.SIGTERM + require.Equal(t, syscall.SIGTERM, <-recv.Done()) + require.Equal(t, syscall.SIGTERM, <-recv.Done()) + sig := <-recv.Wait() + require.Equal(t, syscall.SIGTERM, sig.Signal) + require.NoError(t, recv.Stop(ctx)) + close(stub) + }) + }) + }) }