From 9f1a52001789d8da640edfe1dbf145bbf3d1ea50 Mon Sep 17 00:00:00 2001 From: "Carmen J. Cabezas" Date: Thu, 25 Aug 2022 09:09:25 +0200 Subject: [PATCH] [VIAM-660] Fix memory leaks on kratos anti-password brute forcing protection (#44) --- selfservice/strategy/password/login.go | 49 ++++++++++++-------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/selfservice/strategy/password/login.go b/selfservice/strategy/password/login.go index 7922d242e74a..039c7023e003 100644 --- a/selfservice/strategy/password/login.go +++ b/selfservice/strategy/password/login.go @@ -7,9 +7,9 @@ import ( "context" "encoding/json" "net/http" - "sync" "time" + "github.com/Vonage/go-viam-utils/utils" errors2 "github.com/ory/kratos/schema/errors" "github.com/ory/kratos/selfservice/flowhelpers" @@ -57,8 +57,21 @@ type passCheckStatus struct { numTries uint } -var passCheckCache = make(map[string]passCheckStatus, 10000) -var cacheMutex sync.RWMutex +// This is the validity checker for cache elements. A element is valid if it hasn't expired yet. +func inWindow(v interface{}) bool { + return v.(*passCheckStatus).checkExpiresAt.After(time.Now()) +} + +// passCheckCache is ExpiringCache[string, passCheckStatus) +var passCheckCache = utils.NewCheckedExpiringCache(inWindow, delayReset, 10000).Start() + +func passCheckCacheGet(id string) (passCheckStatus, bool) { + asInt, exists := passCheckCache.SyncGet(id) + if exists { + return asInt.(passCheckStatus), true + } + return passCheckStatus{}, false +} func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, identityID uuid.UUID) (i *identity.Identity, err error) { if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil { @@ -81,21 +94,11 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, s.handleLoginError(w, r, f, &p, err) } - cacheMutex.RLock() - lastCheckResult, exists := passCheckCache[p.Identifier] - cacheMutex.RUnlock() - if exists && lastCheckResult.checkExpiresAt.Before(time.Now()) { - cacheMutex.Lock() - delete(passCheckCache, p.Identifier) - cacheMutex.Unlock() - exists = false - } + lastCheckResult, exists := passCheckCacheGet(p.Identifier) if exists && lastCheckResult.numTries >= delayAfterNumTries { expireAt := lastCheckResult.checkExpiresAt time.Sleep(delayTry) - cacheMutex.RLock() - lastCheckResult, exists = passCheckCache[p.Identifier] - cacheMutex.RUnlock() + lastCheckResult, exists = passCheckCacheGet(p.Identifier) if exists && !expireAt.Equal(lastCheckResult.checkExpiresAt) { time.Sleep(delayTry) // Note that this will probably mean the request will time out. Too bad, so sad. } @@ -104,12 +107,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, id, hashedPwd, exists, cacheEnabled := s.d.CheckPwdCache(p.Identifier) invalidUserDelay := func() { - cacheMutex.Lock() - passCheckCache[p.Identifier] = passCheckStatus{ + passCheckCache.SyncSet(p.Identifier, passCheckStatus{ checkExpiresAt: time.Now().Add(delayReset), numTries: lastCheckResult.numTries + 1, - } - cacheMutex.Unlock() + }) time.Sleep(x.RandomDelay(s.d.Config().HasherArgon2(r.Context()).ExpectedDuration, s.d.Config().HasherArgon2(r.Context()).ExpectedDeviation)) i = nil err = s.handleLoginError(w, r, f, &p, errors.WithStack(errors2.NewInvalidCredentialsError())) @@ -133,12 +134,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } if err := hash.Compare(r.Context(), []byte(p.Password), []byte(hashedPwd)); err != nil { - cacheMutex.Lock() - passCheckCache[p.Identifier] = passCheckStatus{ + passCheckCache.SyncSet(p.Identifier, passCheckStatus{ checkExpiresAt: time.Now().Add(delayReset), numTries: lastCheckResult.numTries + 1, - } - cacheMutex.Unlock() + }) return nil, s.handleLoginError(w, r, f, &p, errors.WithStack(errors2.NewInvalidCredentialsError())) } @@ -156,9 +155,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { return nil, s.handleLoginError(w, r, f, &p, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) } - cacheMutex.Lock() - delete(passCheckCache, p.Identifier) - cacheMutex.Unlock() + passCheckCache.SyncRemove(p.Identifier) if i == nil { userId, _ := uuid.FromString(id) i, err = s.d.PrivilegedIdentityPool().GetIdentity(r.Context(), userId, identity.ExpandDefault)