Skip to content

Commit

Permalink
Merge pull request #1506 from DuendeSoftware/anders/1479_GetAllUserGr…
Browse files Browse the repository at this point in the history
…antsAsync

Return successfully deserialized grants even if some fail
  • Loading branch information
brockallen committed Jan 5, 2024
2 parents a23764b + 2ce4286 commit 6012c36
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 11 deletions.
45 changes: 35 additions & 10 deletions src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs
Expand Up @@ -28,7 +28,7 @@ public class DefaultPersistedGrantService : IPersistedGrantService
/// <param name="store">The store.</param>
/// <param name="serializer">The serializer.</param>
/// <param name="logger">The logger.</param>
public DefaultPersistedGrantService(IPersistedGrantStore store,
public DefaultPersistedGrantService(IPersistedGrantStore store,
IPersistentGrantSerializer serializer,
ILogger<DefaultPersistedGrantService> logger)
{
Expand All @@ -41,18 +41,34 @@ public class DefaultPersistedGrantService : IPersistedGrantService
public async Task<IEnumerable<Grant>> GetAllGrantsAsync(string subjectId)
{
using var activity = Tracing.ServiceActivitySource.StartActivity("DefaultPersistedGrantService.GetAllGrants");

if (String.IsNullOrWhiteSpace(subjectId)) throw new ArgumentNullException(nameof(subjectId));

var grants = (await _store.GetAllAsync(new PersistedGrantFilter { SubjectId = subjectId }))
.Where(x => x.ConsumedTime == null) // filter consumed grants
.ToArray();

List<Exception> errors = new List<Exception>();

T DeserializeAndCaptureErrors<T>(string data)
{
try
{
return _serializer.Deserialize<T>(data);
}
catch (Exception ex)
{
errors.Add(ex);
return default(T);
}
}

try
{
var consents = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.UserConsent)
.Select(x => _serializer.Deserialize<Consent>(x.Data))
.Select(x => new Grant
.Select(x => DeserializeAndCaptureErrors<Consent>(x.Data))
.Where(x => x != default)
.Select(x => new Grant
{
ClientId = x.ClientId,
SubjectId = subjectId,
Expand All @@ -62,7 +78,8 @@ public async Task<IEnumerable<Grant>> GetAllGrantsAsync(string subjectId)
});

var codes = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.AuthorizationCode)
.Select(x => _serializer.Deserialize<AuthorizationCode>(x.Data))
.Select(x => DeserializeAndCaptureErrors<AuthorizationCode>(x.Data))
.Where(x => x != default)
.Select(x => new Grant
{
ClientId = x.ClientId,
Expand All @@ -74,7 +91,8 @@ public async Task<IEnumerable<Grant>> GetAllGrantsAsync(string subjectId)
});

var refresh = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.RefreshToken)
.Select(x => _serializer.Deserialize<RefreshToken>(x.Data))
.Select(x => DeserializeAndCaptureErrors<RefreshToken>(x.Data))
.Where(x => x != default)
.Select(x => new Grant
{
ClientId = x.ClientId,
Expand All @@ -86,7 +104,8 @@ public async Task<IEnumerable<Grant>> GetAllGrantsAsync(string subjectId)
});

var access = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.ReferenceToken)
.Select(x => _serializer.Deserialize<Token>(x.Data))
.Select(x => DeserializeAndCaptureErrors<Token>(x.Data))
.Where(x => x != default)
.Select(x => new Grant
{
ClientId = x.ClientId,
Expand All @@ -101,6 +120,11 @@ public async Task<IEnumerable<Grant>> GetAllGrantsAsync(string subjectId)
consents = Join(consents, refresh);
consents = Join(consents, access);

if (errors.Count > 0)
{
_logger.LogError(new AggregateException(errors), "One or more errors occured during deserialization of persisted grants, returning successfull items.");
}

return consents.ToArray();
}
catch (Exception ex)
Expand All @@ -115,7 +139,7 @@ private IEnumerable<Grant> Join(IEnumerable<Grant> first, IEnumerable<Grant> sec
{
var list = first.ToList();

foreach(var other in second)
foreach (var other in second)
{
var match = list.FirstOrDefault(x => x.ClientId == other.ClientId);
if (match != null)
Expand Down Expand Up @@ -154,10 +178,11 @@ private IEnumerable<Grant> Join(IEnumerable<Grant> first, IEnumerable<Grant> sec
public Task RemoveAllGrantsAsync(string subjectId, string clientId = null, string sessionId = null)
{
using var activity = Tracing.ServiceActivitySource.StartActivity("DefaultPersistedGrantService.RemoveAllGrants");

if (String.IsNullOrWhiteSpace(subjectId)) throw new ArgumentNullException(nameof(subjectId));

return _store.RemoveAllAsync(new PersistedGrantFilter {
return _store.RemoveAllAsync(new PersistedGrantFilter
{
SubjectId = subjectId,
ClientId = clientId,
SessionId = sessionId
Expand Down
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Security.Claims;
using System.Threading.Tasks;
using Duende.IdentityServer;
Expand Down Expand Up @@ -540,4 +541,82 @@ await _userConsent.StoreUserConsentAsync(new Consent()
grants.Count().Should().Be(1);
grants.First().Scopes.Should().Contain(new string[] { "foo1", "foo2", "quux3" });
}
}

[Fact]
public async Task GetAllGrantsAsync_should_filter_items_with_corrupt_data_from_result()
{
var mockStore = new CorruptingPersistedGrantStore(_store)
{
ClientIdToCorrupt = "client2"
};

_subject = new DefaultPersistedGrantService(
mockStore,
new PersistentGrantSerializer(),
TestLogger.Create<DefaultPersistedGrantService>());

await _userConsent.StoreUserConsentAsync(new Consent()
{
ClientId = "client1",
SubjectId = "123",
Scopes = new string[] { "foo1", "foo2" }
});
await _userConsent.StoreUserConsentAsync(new Consent()
{
ClientId = "client2",
SubjectId = "123",
Scopes = new string[] { "foo3" }
});

var grants = await _subject.GetAllGrantsAsync("123");

grants.Count().Should().Be(1);
grants.First().Scopes.Should().Contain(new string[] { "foo1", "foo2" });
}

class CorruptingPersistedGrantStore : IPersistedGrantStore
{
public string ClientIdToCorrupt { get; set; }

private IPersistedGrantStore _inner;

public CorruptingPersistedGrantStore(IPersistedGrantStore inner)
{
_inner = inner;
}

public async Task<IEnumerable<PersistedGrant>> GetAllAsync(PersistedGrantFilter filter)
{
var items = await _inner.GetAllAsync(filter);
if (ClientIdToCorrupt != null)
{
var itemsToCorrupt = items.Where(x => x.ClientId == ClientIdToCorrupt);
foreach(var corruptItem in itemsToCorrupt)
{
corruptItem.Data = "corrupt";
}
}
return items;
}

public Task<PersistedGrant> GetAsync(string key)
{
return _inner.GetAsync(key);
}

public Task RemoveAllAsync(PersistedGrantFilter filter)
{
return _inner.RemoveAllAsync(filter);
}

public Task RemoveAsync(string key)
{
return _inner.RemoveAsync(key);
}

public Task StoreAsync(PersistedGrant grant)
{
return _inner.StoreAsync(grant);
}
}
}

0 comments on commit 6012c36

Please sign in to comment.