From ddff7094ff556609a04f0983e2591d5fee3ce942 Mon Sep 17 00:00:00 2001 From: Gibheer Date: Tue, 17 Mar 2020 23:35:07 +0100 Subject: [PATCH] fix:add graceful shutdown to courier handler (#296) Courier would not stop with the provided Background handler. This changes the methods of Courier so that the graceful package can be used in the same way as the http endpoints can be used. Closes #295 --- cmd/daemon/serve.go | 4 ++-- courier/courier.go | 26 ++++++++++++++++++++------ courier/courier_test.go | 2 +- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/cmd/daemon/serve.go b/cmd/daemon/serve.go index 2745884677f3..6a86d80126e7 100644 --- a/cmd/daemon/serve.go +++ b/cmd/daemon/serve.go @@ -1,7 +1,6 @@ package daemon import ( - stdctx "context" "net/http" "strings" "sync" @@ -165,9 +164,10 @@ func sqa(cmd *cobra.Command, d driver.Driver) *metricsx.Service { func bgTasks(d driver.Driver, wg *sync.WaitGroup, cmd *cobra.Command, args []string) { defer wg.Done() - if err := d.Registry().Courier().Work(stdctx.Background()); err != nil { + if err := graceful.Graceful(d.Registry().Courier().Work, d.Registry().Courier().Shutdown); err != nil { d.Logger().WithError(err).Fatalf("Failed to run courier worker.") } + d.Logger().Println("courier worker was shutdown gracefully") } func ServeAll(d driver.Driver) func(cmd *cobra.Command, args []string) { diff --git a/courier/courier.go b/courier/courier.go index d8c5b6e501a6..b909501a424f 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -27,6 +27,9 @@ type ( dialer *gomail.Dialer d smtpDependencies c configuration.Provider + // graceful shutdown handling + ctx context.Context + shutdown context.CancelFunc } Provider interface { Courier() *Courier @@ -38,9 +41,12 @@ func NewSMTP(d smtpDependencies, c configuration.Provider) *Courier { sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) password, _ := uri.User.Password() port, _ := strconv.ParseInt(uri.Port(), 10, 64) + ctx, cancel := context.WithCancel(context.Background()) return &Courier{ - d: d, - c: c, + d: d, + c: c, + ctx: ctx, + shutdown: cancel, dialer: &gomail.Dialer{ Host: uri.Hostname(), Port: int(port), @@ -82,20 +88,28 @@ func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e return message.ID, nil } -func (m *Courier) Work(ctx context.Context) error { +func (m *Courier) Work() error { errChan := make(chan error) defer close(errChan) - go m.watchMessages(ctx, errChan) + go m.watchMessages(m.ctx, errChan) select { - case <-ctx.Done(): - return ctx.Err() + case <-m.ctx.Done(): + if m.ctx.Err() == context.Canceled { + return nil + } + return m.ctx.Err() case err := <-errChan: return err } } +func (m *Courier) Shutdown(ctx context.Context) error { + m.shutdown() + return nil +} + func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { for { if err := backoff.Retry(func() error { diff --git a/courier/courier_test.go b/courier/courier_test.go index 9de5c02f2555..dfd809208006 100644 --- a/courier/courier_test.go +++ b/courier/courier_test.go @@ -104,7 +104,7 @@ func TestSMTP(t *testing.T) { c := reg.Courier() go func() { - require.NoError(t, c.Work(context.Background())) + require.NoError(t, c.Work()) }() t.Run("case=queue messages", func(t *testing.T) {