diff --git a/src/IdentityServer/Services/Default/KeyManagement/KeyManager.cs b/src/IdentityServer/Services/Default/KeyManagement/KeyManager.cs index 94d1798a2..98c4e2423 100644 --- a/src/IdentityServer/Services/Default/KeyManagement/KeyManager.cs +++ b/src/IdentityServer/Services/Default/KeyManagement/KeyManager.cs @@ -12,6 +12,7 @@ using Duende.IdentityServer.Extensions; using Duende.IdentityServer.Internal; using System.Security.Cryptography; +using Duende.IdentityServer.Models; namespace Duende.IdentityServer.Services.KeyManagement; @@ -335,14 +336,13 @@ internal bool AreAllKeysWithinInitializationDuration(IEnumerable k return result; } - internal async Task> FilterAndDeleteRetiredKeysAsync(IEnumerable keys) + internal async Task> FilterAndDeleteRetiredKeysAsync(IEnumerable keys) { var retired = keys .Where(x => { - var age = _clock.GetAge(x.Created); - var isRetired = _options.KeyManagement.IsRetired(age); - return isRetired; + return (x != null) && + _options.KeyManagement.IsRetired(_clock.GetAge(x.Created)); }) .ToArray(); @@ -428,6 +428,9 @@ internal async Task> GetAllKeysFromStoreAsync(bool cac var protectedKeys = await _store.LoadKeysAsync(); if (protectedKeys != null && protectedKeys.Any()) { + // retired keys are those that are beyond inclusion, thus we act as if they don't exist. + protectedKeys = await FilterAndDeleteRetiredKeysAsync(protectedKeys); + var keys = protectedKeys.Select(x => { try @@ -459,9 +462,7 @@ internal async Task> GetAllKeysFromStoreAsync(bool cac _logger.LogTrace("Loaded keys from store: {kids}", ids.Aggregate((x, y) => $"{x},{y}")); } - // retired keys are those that are beyond inclusion, thus we act as if they don't exist. - keys = await FilterAndDeleteRetiredKeysAsync(keys); - + if (_logger.IsEnabled(LogLevel.Trace) && keys.Any()) { var ids = keys.Select(x => x.Id).ToArray(); diff --git a/test/IdentityServer.UnitTests/Services/Default/KeyManagement/KeyManagerTests.cs b/test/IdentityServer.UnitTests/Services/Default/KeyManagement/KeyManagerTests.cs index c50af9da2..523b18900 100644 --- a/test/IdentityServer.UnitTests/Services/Default/KeyManagement/KeyManagerTests.cs +++ b/test/IdentityServer.UnitTests/Services/Default/KeyManagement/KeyManagerTests.cs @@ -10,6 +10,7 @@ using Duende.IdentityServer.Configuration; using Duende.IdentityServer.Extensions; using Duende.IdentityServer.Internal; +using Duende.IdentityServer.Models; using Duende.IdentityServer.Services.KeyManagement; using FluentAssertions; using Microsoft.Extensions.Logging; @@ -34,7 +35,7 @@ public class KeyManagerTests public KeyManagerTests() { // just to speed up the tests - _options.KeyManagement.InitializationSynchronizationDelay = TimeSpan.FromSeconds(1); + _options.KeyManagement.InitializationSynchronizationDelay = TimeSpan.FromMilliseconds(1); _options.KeyManagement.SigningAlgorithms = new[] { _rsaOptions }; @@ -49,6 +50,12 @@ public KeyManagerTests() new TestIssuerNameService()); } + SerializedKey CreateSerializedKey(TimeSpan? age = null, string alg = "RS256", bool x509 = false) + { + var container = CreateKey(age, alg, x509); + return _mockKeyProtector.Protect(container); + } + KeyContainer CreateKey(TimeSpan? age = null, string alg = "RS256", bool x509 = false) { var key = _options.KeyManagement.CreateRsaSecurityKey(); @@ -69,7 +76,14 @@ string CreateAndStoreKey(TimeSpan? age = null) _mockKeyStore.Keys.Add(_mockKeyProtector.Protect(container)); return container.Id; } - + + string CreateAndStoreKeyThatCannotBeUnprotected(TimeSpan? age = null) + { + var container = CreateKey(age); + _mockKeyStore.Keys.Add(_mockKeyProtector.ProtectAndLoseDataProtectionKey(container)); + return container.Id; + } + string CreateCacheAndStoreKey(TimeSpan? age = null) { var container = CreateKey(age); @@ -458,12 +472,12 @@ public void AreAllKeysWithinInitializationDuration_for_older_keys_should_return_ [Fact] public async Task FilterRetiredKeys_should_filter_retired_keys() { - var key1 = CreateKey(_options.KeyManagement.KeyRetirementAge.Add(TimeSpan.FromSeconds(1))); - var key2 = CreateKey(_options.KeyManagement.KeyRetirementAge); - var key3 = CreateKey(_options.KeyManagement.KeyRetirementAge.Subtract(TimeSpan.FromSeconds(1))); - var key4 = CreateKey(_options.KeyManagement.PropagationTime.Add(TimeSpan.FromSeconds(1))); - var key5 = CreateKey(_options.KeyManagement.PropagationTime); - var key6 = CreateKey(_options.KeyManagement.PropagationTime.Subtract(TimeSpan.FromSeconds(1))); + var key1 = CreateSerializedKey(_options.KeyManagement.KeyRetirementAge.Add(TimeSpan.FromSeconds(1))); + var key2 = CreateSerializedKey(_options.KeyManagement.KeyRetirementAge); + var key3 = CreateSerializedKey(_options.KeyManagement.KeyRetirementAge.Subtract(TimeSpan.FromSeconds(1))); + var key4 = CreateSerializedKey(_options.KeyManagement.PropagationTime.Add(TimeSpan.FromSeconds(1))); + var key5 = CreateSerializedKey(_options.KeyManagement.PropagationTime); + var key6 = CreateSerializedKey(_options.KeyManagement.PropagationTime.Subtract(TimeSpan.FromSeconds(1))); var result = await _subject.FilterAndDeleteRetiredKeysAsync(new[] { key1, key2, key3, key4, key5, key6 }); @@ -594,6 +608,24 @@ public async Task GetKeysFromStoreAsync_should_filter_retired_keys() keys.Select(x => x.Id).Should().BeEquivalentTo(new[] { key1, key2, key3, key4 }); } + [Fact] + public async Task GetKeysFromStoreAsync_should_filter_retired_keys_that_cannot_be_unprotected() + { + var key1 = CreateAndStoreKey(TimeSpan.FromSeconds(10)); + var key2 = CreateAndStoreKey(TimeSpan.FromSeconds(5)); + var key3 = CreateAndStoreKey(-TimeSpan.FromSeconds(5)); + var key4 = CreateAndStoreKey(_options.KeyManagement.RotationInterval.Add(TimeSpan.FromSeconds(1))); + var key5 = CreateAndStoreKeyThatCannotBeUnprotected(_options.KeyManagement.KeyRetirementAge.Add(TimeSpan.FromSeconds(5))); + + var keys = await _subject.GetAllKeysFromStoreAsync(); + + keys.Select(x => x.Id).Should().BeEquivalentTo(new[] { key1, key2, key3, key4 }); + + _mockKeyStore.DeleteWasCalled.Should().BeTrue(); + var keysInStore = await _mockKeyStore.LoadKeysAsync(); + keysInStore.Select(x => x.Id).Should().BeEquivalentTo(new[] { key1, key2, key3, key4 }); + } + [Fact] public async Task GetKeysFromStoreAsync_should_filter_null_keys() { diff --git a/test/IdentityServer.UnitTests/Services/Default/KeyManagement/MockSigningKeyProtector.cs b/test/IdentityServer.UnitTests/Services/Default/KeyManagement/MockSigningKeyProtector.cs index 3e7de6ea5..119b10ed3 100644 --- a/test/IdentityServer.UnitTests/Services/Default/KeyManagement/MockSigningKeyProtector.cs +++ b/test/IdentityServer.UnitTests/Services/Default/KeyManagement/MockSigningKeyProtector.cs @@ -4,13 +4,21 @@ using Duende.IdentityServer.Models; using Duende.IdentityServer.Services.KeyManagement; +using Microsoft.AspNetCore.DataProtection; using System; namespace UnitTests.Services.Default.KeyManagement; class MockSigningKeyProtector : ISigningKeyProtector { + private IDataProtector _dataProtector; public bool ProtectWasCalled { get; set; } + + public MockSigningKeyProtector() + { + var provider = new EphemeralDataProtectionProvider(); + _dataProtector = provider.CreateProtector("test"); + } public SerializedKey Protect(KeyContainer key) { @@ -20,13 +28,32 @@ public SerializedKey Protect(KeyContainer key) Id = key.Id, Algorithm = key.Algorithm, IsX509Certificate = key.HasX509Certificate, - Created = DateTime.UtcNow, - Data = KeySerializer.Serialize(key), + Created = key.Created, + Data = _dataProtector.Protect(KeySerializer.Serialize(key)), + }; + } + + /// + /// Simulate a situation where a signing key was protected in the past with a signing key that is no longer available + /// + public SerializedKey ProtectAndLoseDataProtectionKey(KeyContainer key) + { + var provider = new EphemeralDataProtectionProvider(); + var badProtector = provider.CreateProtector("unavailable-when-we-unprotect"); + + ProtectWasCalled = true; + return new SerializedKey + { + Id = key.Id, + Algorithm = key.Algorithm, + IsX509Certificate = key.HasX509Certificate, + Created = key.Created, + Data = badProtector.Protect(KeySerializer.Serialize(key)), }; } public KeyContainer Unprotect(SerializedKey key) { - return KeySerializer.Deserialize(key.Data); + return KeySerializer.Deserialize(_dataProtector.Unprotect(key.Data)); } }