Skip to content

Commit

Permalink
Forced rotation: Remove cached JWT-SVIDs using tainted keys (#5565)
Browse files Browse the repository at this point in the history
* Remove from the agent cache the JWT-SVIDs using tainted keys

Signed-off-by: Agustín Martínez Fayó <[email protected]>
  • Loading branch information
amartinezfayo authored Oct 16, 2024
1 parent 3715714 commit b80bf4e
Show file tree
Hide file tree
Showing 9 changed files with 368 additions and 53 deletions.
73 changes: 69 additions & 4 deletions pkg/agent/manager/cache/jwt_cache.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,40 @@
package cache

import (
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"sort"
"sync"

"github.com/go-jose/go-jose/v4/jwt"
"github.com/sirupsen/logrus"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/spire/pkg/agent/client"
"github.com/spiffe/spire/pkg/common/jwtsvid"
"github.com/spiffe/spire/pkg/common/telemetry"
"github.com/spiffe/spire/pkg/common/telemetry/agent"
)

// JWTSVIDCache is an in-memory cache of JWT-SVIDs stored in a map with
// string keys. It also carries the logger and metrics sink that are used
// when cached JWT-SVIDs are removed because they were issued by a tainted
// authority.
type JWTSVIDCache struct {
mu sync.Mutex
svids map[string]*client.JWTSVID
log logrus.FieldLogger
metrics telemetry.Metrics
// mu is taken before svids is modified.
mu sync.RWMutex
// svids holds the cached JWT-SVIDs.
svids map[string]*client.JWTSVID
}

// CountJWTSVIDs returns the number of JWT-SVIDs currently held in the cache.
//
// The read lock is taken because svids is modified under c.mu by other
// methods (for example TaintJWTSVIDs deletes entries while holding it), so
// an unguarded len() here would be a data race.
func (c *JWTSVIDCache) CountJWTSVIDs() int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return len(c.svids)
}

// NewJWTSVIDCache returns a JWTSVIDCache with an empty svids map that uses
// the given logger and metrics sink.
func NewJWTSVIDCache() *JWTSVIDCache {
func NewJWTSVIDCache(log logrus.FieldLogger, metrics telemetry.Metrics) *JWTSVIDCache {
return &JWTSVIDCache{
svids: make(map[string]*client.JWTSVID),
metrics: metrics,
log: log,
svids: make(map[string]*client.JWTSVID),
}
}

Expand All @@ -43,6 +55,59 @@ func (c *JWTSVIDCache) SetJWTSVID(spiffeID spiffeid.ID, audience []string, svid
c.svids[key] = svid
}

// TaintJWTSVIDs removes from the cache every JWT-SVID whose token header
// carries a key ID that is present in taintedJWTAuthorities.
//
// The cache lock is held for the whole pass. A cached token whose key ID
// cannot be determined is logged and left in the cache. If ctx is cancelled
// the pass stops before examining the next entry; the entries removed up to
// that point are still reported through the per-key log lines and the
// tainted JWT-SVIDs metric sample, so the telemetry always matches what was
// actually deleted.
func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities map[string]struct{}) {
	c.mu.Lock()
	defer c.mu.Unlock()

	counter := telemetry.StartCall(c.metrics, telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs)
	defer counter.Done(nil)

	// removedKeyIDs counts removed JWT-SVIDs per tainted key ID; totalCount
	// is the overall number of removed entries.
	removedKeyIDs := make(map[string]int)
	totalCount := 0
	for key, jwtSVID := range c.svids {
		// Check for cancellation at the top of every iteration so the check
		// also runs after an entry that failed to parse, and break (rather
		// than return) so the removals done so far are still reported below.
		if err := ctx.Err(); err != nil {
			c.log.WithError(err).Warn("Context cancelled, exiting process of tainting JWT-SVIDs in cache")
			break
		}

		keyID, err := getKeyIDFromSVIDToken(jwtSVID.Token)
		if err != nil {
			c.log.WithError(err).Error("Could not get key ID from cached JWT-SVID")
			continue
		}
		if _, tainted := taintedJWTAuthorities[keyID]; tainted {
			// Deleting the current key while ranging over a map is safe in Go.
			delete(c.svids, key)
			removedKeyIDs[keyID]++
			totalCount++
		}
	}

	for keyID, count := range removedKeyIDs {
		c.log.WithField(telemetry.JWTAuthorityKeyIDs, keyID).
			WithField(telemetry.TaintedJWTSVIDs, count).
			Info("JWT-SVIDs were removed from the JWT cache because they were issued by a tainted authority")
	}
	agent.AddCacheManagerTaintedJWTSVIDsSample(c.metrics, agent.CacheTypeWorkload, float32(totalCount))
}

// getKeyIDFromSVIDToken parses svidToken as a signed JWT and returns the key
// ID found in its header.
//
// An error is returned when the token cannot be parsed, when it does not
// have exactly one header, or when that header has an empty key ID.
func getKeyIDFromSVIDToken(svidToken string) (string, error) {
	parsed, parseErr := jwt.ParseSigned(svidToken, jwtsvid.AllowedSignatureAlgorithms)
	if parseErr != nil {
		return "", fmt.Errorf("failed to parse JWT-SVID: %w", parseErr)
	}

	headers := parsed.Headers
	// Cases are evaluated in order, so headers[0] is only read once the
	// header count has been confirmed to be exactly one.
	switch {
	case len(headers) != 1:
		return "", fmt.Errorf("malformed JWT-SVID: expected a single token header; got %d", len(headers))
	case headers[0].KeyID == "":
		return "", errors.New("missing key ID in token header of minted JWT-SVID")
	}

	return headers[0].KeyID, nil
}

func jwtSVIDKey(spiffeID spiffeid.ID, audience []string) string {
h := sha256.New()

Expand Down
207 changes: 201 additions & 6 deletions pkg/agent/manager/cache/jwt_cache_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
package cache

import (
"context"
"testing"
"time"

"github.com/hashicorp/go-metrics"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/spire/pkg/agent/client"
"github.com/spiffe/spire/pkg/common/telemetry"
"github.com/spiffe/spire/pkg/common/telemetry/agent"
"github.com/spiffe/spire/test/fakes/fakemetrics"
"github.com/spiffe/spire/test/spiretest"
"github.com/stretchr/testify/assert"
)

func TestJWTSVIDCacheBasic(t *testing.T) {
func TestJWTSVIDCache(t *testing.T) {
now := time.Now()
expected := &client.JWTSVID{Token: "X", IssuedAt: now, ExpiresAt: now.Add(time.Second)}
tok1 := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImRaRGZZaXcxdUd6TXdkTVlITDdGRVl5SzhIT0tLd0xYIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjU3MzEsImlhdCI6MTcyNDI3OTQwNywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.dFr-oWhm5tK0bBuVXt-sGESM5l7hhoY-Gtt5DkuFoJL5Y9d4ZfmicCvUCjL4CqDB3BO_cPqmFfrO7H7pxQbGLg"
tok2 := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImNKMXI5TVY4OTZTWXBMY0RMUjN3Q29QRHprTXpkN25tIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3Mjg1NzEwMjUsImlhdCI6MTcyODU3MDcyNSwic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.1YnDj7nknwIHEuNKEN0cNypXKS4SUeILXlNOsOs2XElHzfKhhDcl0sYKYtQc1Itf6cygz9C16VOQ_Yjoos2Qfg"
jwtSVID := &client.JWTSVID{Token: tok1, IssuedAt: now, ExpiresAt: now.Add(time.Second)}
jwtSVID2 := &client.JWTSVID{Token: tok2, IssuedAt: now, ExpiresAt: now.Add(time.Second)}

cache := NewJWTSVIDCache()
fakeMetrics := fakemetrics.New()
log, logHook := test.NewNullLogger()
log.Level = logrus.DebugLevel
cache := NewJWTSVIDCache(log, fakeMetrics)

spiffeID := spiffeid.RequireFromString("spiffe://example.org/blog")

Expand All @@ -23,18 +37,199 @@ func TestJWTSVIDCacheBasic(t *testing.T) {
assert.Nil(t, actual)

// JWT is cached
cache.SetJWTSVID(spiffeID, []string{"bar"}, expected)
cache.SetJWTSVID(spiffeID, []string{"bar"}, jwtSVID)
actual, ok = cache.GetJWTSVID(spiffeID, []string{"bar"})
assert.True(t, ok)
assert.Equal(t, expected, actual)
assert.Equal(t, jwtSVID, actual)

// Test tainting of JWT-SVIDs
ctx := context.Background()
keyID1 := "dZDfYiw1uGzMwdMYHL7FEYyK8HOKKwLX"
keyID2 := "cJ1r9MV896SYpLcDLR3wCoPDzkMzd7nm"
for _, tt := range []struct {
name string
taintedKeyIDs map[string]struct{}
setJWTSVIDsCached func(cache *JWTSVIDCache)
expectLogs []spiretest.LogEntry
expectMetrics []fakemetrics.MetricItem
}{
{
name: "one authority tainted, one JWT-SVID",
taintedKeyIDs: map[string]struct{}{keyID1: {}},
setJWTSVIDsCached: func(cache *JWTSVIDCache) {
cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID)
},
expectLogs: []spiretest.LogEntry{
{
Level: logrus.InfoLevel,
Message: "JWT-SVIDs were removed from the JWT cache because they were issued by a tainted authority",
Data: logrus.Fields{
telemetry.TaintedJWTSVIDs: "1",
telemetry.JWTAuthorityKeyIDs: keyID1,
},
},
},
expectMetrics: []fakemetrics.MetricItem{
{
Type: fakemetrics.AddSampleType,
Key: []string{telemetry.CacheManager, telemetry.TaintedJWTSVIDs, agent.CacheTypeWorkload},
Val: 1,
},
{
Type: fakemetrics.IncrCounterWithLabelsType,
Key: []string{telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs},
Val: 1,
Labels: []metrics.Label{{Name: "status", Value: "OK"}},
},
{
Type: fakemetrics.MeasureSinceWithLabelsType,
Key: []string{telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs, telemetry.ElapsedTime},
Val: 0,
Labels: []metrics.Label{{Name: "status", Value: "OK"}},
},
},
},
{
name: "one authority tainted, multiple JWT-SVIDs",
taintedKeyIDs: map[string]struct{}{keyID1: {}},
setJWTSVIDsCached: func(cache *JWTSVIDCache) {
cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID)
cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID)
},
expectLogs: []spiretest.LogEntry{
{
Level: logrus.InfoLevel,
Message: "JWT-SVIDs were removed from the JWT cache because they were issued by a tainted authority",
Data: logrus.Fields{
telemetry.TaintedJWTSVIDs: "2",
telemetry.JWTAuthorityKeyIDs: keyID1,
},
},
},
expectMetrics: []fakemetrics.MetricItem{
{
Type: fakemetrics.AddSampleType,
Key: []string{telemetry.CacheManager, telemetry.TaintedJWTSVIDs, agent.CacheTypeWorkload},
Val: 2,
},
{
Type: fakemetrics.IncrCounterWithLabelsType,
Key: []string{telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs},
Val: 1,
Labels: []metrics.Label{{Name: "status", Value: "OK"}},
},
{
Type: fakemetrics.MeasureSinceWithLabelsType,
Key: []string{telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs, telemetry.ElapsedTime},
Val: 0,
Labels: []metrics.Label{{Name: "status", Value: "OK"}},
},
},
},
{
name: "multiple authorities tainted, multiple JWT-SVIDs",
taintedKeyIDs: map[string]struct{}{keyID1: {}, keyID2: {}},
setJWTSVIDsCached: func(cache *JWTSVIDCache) {
cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID)
cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID)
cache.SetJWTSVID(spiffeID, []string{"audience-3"}, jwtSVID2)
},
expectLogs: []spiretest.LogEntry{
{
Level: logrus.InfoLevel,
Message: "JWT-SVIDs were removed from the JWT cache because they were issued by a tainted authority",
Data: logrus.Fields{
telemetry.TaintedJWTSVIDs: "2",
telemetry.JWTAuthorityKeyIDs: keyID1,
},
},
{
Level: logrus.InfoLevel,
Message: "JWT-SVIDs were removed from the JWT cache because they were issued by a tainted authority",
Data: logrus.Fields{
telemetry.TaintedJWTSVIDs: "1",
telemetry.JWTAuthorityKeyIDs: keyID2,
},
},
},
expectMetrics: []fakemetrics.MetricItem{
{
Type: fakemetrics.AddSampleType,
Key: []string{telemetry.CacheManager, telemetry.TaintedJWTSVIDs, agent.CacheTypeWorkload},
Val: 3,
},
{
Type: fakemetrics.IncrCounterWithLabelsType,
Key: []string{telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs},
Val: 1,
Labels: []metrics.Label{{Name: "status", Value: "OK"}},
},
{
Type: fakemetrics.MeasureSinceWithLabelsType,
Key: []string{telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs, telemetry.ElapsedTime},
Val: 0,
Labels: []metrics.Label{{Name: "status", Value: "OK"}},
},
},
},
{
name: "none of the authorities tainted is in cache",
taintedKeyIDs: map[string]struct{}{"not-cached-1": {}, "not-cached-2": {}},
setJWTSVIDsCached: func(cache *JWTSVIDCache) {
cache.SetJWTSVID(spiffeID, []string{"audience-1"}, jwtSVID)
cache.SetJWTSVID(spiffeID, []string{"audience-2"}, jwtSVID)
cache.SetJWTSVID(spiffeID, []string{"audience-3"}, jwtSVID2)
},
expectMetrics: []fakemetrics.MetricItem{
{
Type: fakemetrics.AddSampleType,
Key: []string{telemetry.CacheManager, telemetry.TaintedJWTSVIDs, agent.CacheTypeWorkload},
Val: 0,
},
{
Type: fakemetrics.IncrCounterWithLabelsType,
Key: []string{telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs},
Val: 1,
Labels: []metrics.Label{{Name: "status", Value: "OK"}},
},
{
Type: fakemetrics.MeasureSinceWithLabelsType,
Key: []string{telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedJWTSVIDs, telemetry.ElapsedTime},
Val: 0,
Labels: []metrics.Label{{Name: "status", Value: "OK"}},
},
},
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
cache := NewJWTSVIDCache(log, fakeMetrics)
if tt.setJWTSVIDsCached != nil {
tt.setJWTSVIDsCached(cache)
}

// Remove tainted authority, should not be cached anymore
cache.TaintJWTSVIDs(ctx, tt.taintedKeyIDs)
actual, ok = cache.GetJWTSVID(spiffeID, []string{"bar"})
assert.False(t, ok)
assert.Nil(t, actual)

spiretest.AssertLogsAnyOrder(t, logHook.AllEntries(), tt.expectLogs)
assert.Equal(t, tt.expectMetrics, fakeMetrics.AllMetrics())
resetLogsAndMetrics(logHook, fakeMetrics)
})
}
}

func TestJWTSVIDCacheKeyHashing(t *testing.T) {
spiffeID := spiffeid.RequireFromString("spiffe://example.org/blog")
now := time.Now()
expected := &client.JWTSVID{Token: "X", IssuedAt: now, ExpiresAt: now.Add(time.Second)}

cache := NewJWTSVIDCache()
fakeMetrics := fakemetrics.New()
log, _ := test.NewNullLogger()
log.Level = logrus.DebugLevel
cache := NewJWTSVIDCache(log, fakeMetrics)
cache.SetJWTSVID(spiffeID, []string{"ab", "cd"}, expected)

// JWT is cached
Expand Down
11 changes: 6 additions & 5 deletions pkg/agent/manager/cache/lru_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/spire/pkg/common/backoff"
"github.com/spiffe/spire/pkg/common/telemetry"
"github.com/spiffe/spire/pkg/common/telemetry/agent"
agentmetrics "github.com/spiffe/spire/pkg/common/telemetry/agent"
"github.com/spiffe/spire/pkg/common/x509util"
"github.com/spiffe/spire/proto/spire/common"
Expand Down Expand Up @@ -42,7 +43,7 @@ type UpdateEntries struct {
TaintedX509Authorities []string

// TaintedJWTAuthorities is a set of all tainted JWT authorities notified by the server.
TaintedJWTAuthorities []string
TaintedJWTAuthorities map[string]struct{}

// RegistrationEntries is a set of all registration entries available to the
// agent, keyed by registration entry id.
Expand Down Expand Up @@ -156,7 +157,7 @@ func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundl

return &LRUCache{
BundleCache: NewBundleCache(trustDomain, bundle),
JWTSVIDCache: NewJWTSVIDCache(),
JWTSVIDCache: NewJWTSVIDCache(log, metrics),

log: log,
metrics: metrics,
Expand Down Expand Up @@ -635,7 +636,7 @@ func (c *LRUCache) notifyTaintedBatchProcessed() {

// processTaintedSVIDs identifies and removes tainted SVIDs from the cache that have been signed by the given tainted authorities.
func (c *LRUCache) processTaintedSVIDs(entryIDs []string, taintedX509Authorities []*x509.Certificate) {
counter := telemetry.StartCall(c.metrics, telemetry.CacheManager, "", telemetry.ProcessTaintedSVIDs)
counter := telemetry.StartCall(c.metrics, telemetry.CacheManager, agent.CacheTypeWorkload, telemetry.ProcessTaintedX509SVIDs)
defer counter.Done(nil)

taintedSVIDs := 0
Expand Down Expand Up @@ -664,8 +665,8 @@ func (c *LRUCache) processTaintedSVIDs(entryIDs []string, taintedX509Authorities
}
}

agentmetrics.AddCacheManagerTaintedSVIDsSample(c.metrics, "", float32(taintedSVIDs))
c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Info("Tainted X.509 SVIDs")
agentmetrics.AddCacheManagerTaintedX509SVIDsSample(c.metrics, agentmetrics.CacheTypeWorkload, float32(taintedSVIDs))
c.log.WithField(telemetry.TaintedX509SVIDs, taintedSVIDs).Info("Tainted X.509 SVIDs")
}

// Notify subscriber of selector set only if all SVIDs for corresponding selector set are cached
Expand Down
Loading

0 comments on commit b80bf4e

Please sign in to comment.