diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go index ca558e70c9..1858184e81 100644 --- a/acme/autocert/autocert.go +++ b/acme/autocert/autocert.go @@ -458,7 +458,7 @@ func (m *Manager) cert(ctx context.Context, ck certKey) (*tls.Certificate, error leaf: cert.Leaf, } m.state[ck] = s - go m.renew(ck, s.key, s.leaf.NotAfter) + go m.startRenew(ck, s.key, s.leaf.NotAfter) return cert, nil } @@ -584,8 +584,9 @@ func (m *Manager) createCert(ctx context.Context, ck certKey) (*tls.Certificate, if err != nil { // Remove the failed state after some time, // making the manager call createCert again on the following TLS hello. + didRemove := testDidRemoveState // The lifetime of this timer is untracked, so copy mutable local state to avoid races. time.AfterFunc(createCertRetryAfter, func() { - defer testDidRemoveState(ck) + defer didRemove(ck) m.stateMu.Lock() defer m.stateMu.Unlock() // Verify the state hasn't changed and it's still invalid @@ -603,7 +604,7 @@ func (m *Manager) createCert(ctx context.Context, ck certKey) (*tls.Certificate, } state.cert = der state.leaf = leaf - go m.renew(ck, state.key, state.leaf.NotAfter) + go m.startRenew(ck, state.key, state.leaf.NotAfter) return state.tlscert() } @@ -893,7 +894,7 @@ func httpTokenCacheKey(tokenPath string) string { return path.Base(tokenPath) + "+http-01" } -// renew starts a cert renewal timer loop, one per domain. +// startRenew starts a cert renewal timer loop, one per domain. // // The loop is scheduled in two cases: // - a cert was fetched from cache for the first time (wasn't in m.state) @@ -901,7 +902,7 @@ func httpTokenCacheKey(tokenPath string) string { // // The key argument is a certificate private key. // The exp argument is the cert expiration time (NotAfter). -func (m *Manager) renew(ck certKey, key crypto.Signer, exp time.Time) { +func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) { m.renewalMu.Lock() defer m.renewalMu.Unlock() if m.renewal[ck] != nil { diff --git a/acme/autocert/renewal.go b/acme/autocert/renewal.go index 665f870dcd..0df7da78a6 100644 --- a/acme/autocert/renewal.go +++ b/acme/autocert/renewal.go @@ -21,8 +21,9 @@ type domainRenewal struct { ck certKey key crypto.Signer - timerMu sync.Mutex - timer *time.Timer + timerMu sync.Mutex + timer *time.Timer + timerClose chan struct{} // if non-nil, renew closes this channel (and nils out the timer fields) instead of running } // start starts a cert renewal timer at the time @@ -38,16 +39,28 @@ func (dr *domainRenewal) start(exp time.Time) { dr.timer = time.AfterFunc(dr.next(exp), dr.renew) } -// stop stops the cert renewal timer. -// If the timer is already stopped, calling stop is a noop. +// stop stops the cert renewal timer and waits for any in-flight calls to renew +// to complete. If the timer is already stopped, calling stop is a noop. func (dr *domainRenewal) stop() { dr.timerMu.Lock() defer dr.timerMu.Unlock() - if dr.timer == nil { - return + for { + if dr.timer == nil { + return + } + if dr.timer.Stop() { + dr.timer = nil + return + } else { + // dr.timer fired, and we acquired dr.timerMu before the renew callback did. + // (We know this because otherwise the renew callback would have reset dr.timer!) + timerClose := make(chan struct{}) + dr.timerClose = timerClose + dr.timerMu.Unlock() + <-timerClose + dr.timerMu.Lock() + } } - dr.timer.Stop() - dr.timer = nil } // renew is called periodically by a timer. @@ -55,7 +68,9 @@ func (dr *domainRenewal) stop() { func (dr *domainRenewal) renew() { dr.timerMu.Lock() defer dr.timerMu.Unlock() - if dr.timer == nil { + if dr.timerClose != nil { + close(dr.timerClose) + dr.timer, dr.timerClose = nil, nil return } @@ -67,8 +82,8 @@ func (dr *domainRenewal) renew() { next = renewJitter / 2 next += time.Duration(pseudoRand.int63n(int64(next))) } - dr.timer = time.AfterFunc(next, dr.renew) testDidRenewLoop(next, err) + dr.timer = time.AfterFunc(next, dr.renew) } // updateState locks and replaces the relevant Manager.state item with the given diff --git a/acme/autocert/renewal_test.go b/acme/autocert/renewal_test.go index e5f48ffbb8..ffe4af2a5c 100644 --- a/acme/autocert/renewal_test.go +++ b/acme/autocert/renewal_test.go @@ -61,11 +61,23 @@ func TestRenewFromCache(t *testing.T) { // verify the renewal happened defer func() { + // Stop the timers that read and execute testDidRenewLoop before restoring it. + // Otherwise the timer callback may race with the deferred write. + man.stopRenew() testDidRenewLoop = func(next time.Duration, err error) {} }() - done := make(chan struct{}) + renewed := make(chan bool, 1) testDidRenewLoop = func(next time.Duration, err error) { - defer close(done) + defer func() { + select { + case renewed <- true: + default: + // The renewal timer uses a random backoff. If the first renewal fails for + // some reason, we could end up with multiple calls here before the test + // stops the timer. + } + }() + if err != nil { t.Errorf("testDidRenewLoop: %v", err) } @@ -81,7 +93,8 @@ func TestRenewFromCache(t *testing.T) { after := time.Now().Add(future) tlscert, err := man.cacheGet(context.Background(), exampleCertKey) if err != nil { - t.Fatalf("man.cacheGet: %v", err) + t.Errorf("man.cacheGet: %v", err) + return } if !tlscert.Leaf.NotAfter.After(after) { t.Errorf("cache leaf.NotAfter = %v; want > %v", tlscert.Leaf.NotAfter, after) @@ -92,11 +105,13 @@ func TestRenewFromCache(t *testing.T) { defer man.stateMu.Unlock() s := man.state[exampleCertKey] if s == nil { - t.Fatalf("m.state[%q] is nil", exampleCertKey) + t.Errorf("m.state[%q] is nil", exampleCertKey) + return } tlscert, err = s.tlscert() if err != nil { - t.Fatalf("s.tlscert: %v", err) + t.Errorf("s.tlscert: %v", err) + return } if !tlscert.Leaf.NotAfter.After(after) { t.Errorf("state leaf.NotAfter = %v; want > %v", tlscert.Leaf.NotAfter, after) @@ -108,13 +123,7 @@ func TestRenewFromCache(t *testing.T) { if _, err := man.GetCertificate(hello); err != nil { t.Fatal(err) } - - // wait for renew loop - select { - case <-time.After(10 * time.Second): - t.Fatal("renew took too long to occur") - case <-done: - } + <-renewed } func TestRenewFromCacheAlreadyRenewed(t *testing.T) { @@ -159,11 +168,23 @@ func TestRenewFromCacheAlreadyRenewed(t *testing.T) { // verify the renewal accepted the newer cached cert defer func() { + // Stop the timers that read and execute testDidRenewLoop before restoring it. + // Otherwise the timer callback may race with the deferred write. + man.stopRenew() testDidRenewLoop = func(next time.Duration, err error) {} }() - done := make(chan struct{}) + renewed := make(chan bool, 1) testDidRenewLoop = func(next time.Duration, err error) { - defer close(done) + defer func() { + select { + case renewed <- true: + default: + // The renewal timer uses a random backoff. If the first renewal fails for + // some reason, we could end up with multiple calls here before the test + // stops the timer. + } + }() + if err != nil { t.Errorf("testDidRenewLoop: %v", err) } @@ -177,7 +198,8 @@ func TestRenewFromCacheAlreadyRenewed(t *testing.T) { // ensure the cached cert was not modified tlscert, err := man.cacheGet(context.Background(), exampleCertKey) if err != nil { - t.Fatalf("man.cacheGet: %v", err) + t.Errorf("man.cacheGet: %v", err) + return } if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) { t.Errorf("cache leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter) @@ -188,30 +210,22 @@ func TestRenewFromCacheAlreadyRenewed(t *testing.T) { defer man.stateMu.Unlock() s := man.state[exampleCertKey] if s == nil { - t.Fatalf("m.state[%q] is nil", exampleCertKey) + t.Errorf("m.state[%q] is nil", exampleCertKey) + return } stateKey := s.key.Public().(*ecdsa.PublicKey) if !stateKey.Equal(newLeaf.PublicKey) { - t.Fatal("state key was not updated from cache") + t.Error("state key was not updated from cache") + return } tlscert, err = s.tlscert() if err != nil { - t.Fatalf("s.tlscert: %v", err) + t.Errorf("s.tlscert: %v", err) + return } if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) { t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter) } - - // verify the private key is replaced in the renewal state - r := man.renewal[exampleCertKey] - if r == nil { - t.Fatalf("m.renewal[%q] is nil", exampleCertKey) - } - renewalKey := r.key.Public().(*ecdsa.PublicKey) - if !renewalKey.Equal(newLeaf.PublicKey) { - t.Fatal("renewal private key was not updated from cache") - } - } // assert the expiring cert is returned from state @@ -225,21 +239,31 @@ func TestRenewFromCacheAlreadyRenewed(t *testing.T) { } // trigger renew - go man.renew(exampleCertKey, s.key, s.leaf.NotAfter) - - // wait for renew loop - select { - case <-time.After(10 * time.Second): - t.Fatal("renew took too long to occur") - case <-done: - // assert the new cert is returned from state after renew - hello := clientHelloInfo(exampleDomain, algECDSA) - tlscert, err := man.GetCertificate(hello) - if err != nil { - t.Fatal(err) + man.startRenew(exampleCertKey, s.key, s.leaf.NotAfter) + <-renewed + func() { + man.renewalMu.Lock() + defer man.renewalMu.Unlock() + + // verify the private key is replaced in the renewal state + r := man.renewal[exampleCertKey] + if r == nil { + t.Errorf("m.renewal[%q] is nil", exampleCertKey) + return } - if !newLeaf.NotAfter.Equal(tlscert.Leaf.NotAfter) { - t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter) + renewalKey := r.key.Public().(*ecdsa.PublicKey) + if !renewalKey.Equal(newLeaf.PublicKey) { + t.Error("renewal private key was not updated from cache") } + }() + + // assert the new cert is returned from state after renew + hello = clientHelloInfo(exampleDomain, algECDSA) + tlscert, err = man.GetCertificate(hello) + if err != nil { + t.Fatal(err) + } + if !newLeaf.NotAfter.Equal(tlscert.Leaf.NotAfter) { + t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter) } }