Skip to content

Commit

Permalink
Merge pull request #4017 from Particular/john/scatter
Browse files Browse the repository at this point in the history
Make ScatterGather not depend on HttpContext
  • Loading branch information
johnsimons committed Mar 20, 2024
2 parents 4ac3a62 + d1bf96a commit e870991
Show file tree
Hide file tree
Showing 16 changed files with 166 additions and 94 deletions.
Expand Up @@ -5,18 +5,17 @@
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using CompositeViews.Messages;
using NUnit.Framework;
using Persistence.Infrastructure;
using ServiceBus.Management.Infrastructure.Settings;
using ServiceControl.CompositeViews.Messages;
using ServiceControl.Persistence.Infrastructure;

abstract class MessageView_ScatterGatherTest
{
[SetUp]
public void SetUp()
{
var api = new TestApi(null, null, null, null);
var api = new TestApi(null, null, null);

Results = api.AggregateResults(new ScatterGatherApiMessageViewContext(null, new SortInfo()), GetData());
}
Expand Down Expand Up @@ -67,8 +66,8 @@ protected IEnumerable<MessagesView> RemoteData()

class TestApi : ScatterGatherApiMessageView<object, ScatterGatherApiMessageViewContext>
{
public TestApi(object dataStore, Settings settings, IHttpClientFactory httpClientFactory, IHttpContextAccessor httpContextAccessor)
: base(dataStore, settings, httpClientFactory, httpContextAccessor)
public TestApi(object dataStore, Settings settings, IHttpClientFactory httpClientFactory)
: base(dataStore, settings, httpClientFactory)
{
}

Expand Down
Expand Up @@ -4,11 +4,10 @@
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Messages;
using Persistence;
using Persistence.Infrastructure;
using ServiceBus.Management.Infrastructure.Settings;
using ServiceControl.CompositeViews.Messages;
using ServiceControl.Persistence;
using ServiceControl.Persistence.Infrastructure;

// The endpoint is included for consistency reasons but is actually not required here because the query
// is forwarded to the remote instance. But this at least enforces us to declare the controller action
Expand All @@ -19,10 +18,10 @@ public record AuditCountsForEndpointContext(PagingInfo PagingInfo, string Endpoi
public class GetAuditCountsForEndpointApi(
IErrorMessageDataStore dataStore,
Settings settings,
IHttpClientFactory httpClientFactory,
IHttpContextAccessor httpContextAccessor)
IHttpClientFactory httpClientFactory
)
: ScatterGatherApi<IErrorMessageDataStore, AuditCountsForEndpointContext, IList<AuditCount>>(dataStore, settings,
httpClientFactory, httpContextAccessor)
httpClientFactory)
{
static readonly IList<AuditCount> Empty = new List<AuditCount>(0).AsReadOnly();

Expand Down
Expand Up @@ -3,14 +3,14 @@ namespace ServiceControl.CompositeViews.Messages
using System.Collections.Generic;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Persistence;
using Persistence.Infrastructure;
using ServiceBus.Management.Infrastructure.Settings;
using ServiceControl.Persistence;
using ServiceControl.Persistence.Infrastructure;

public class GetAllMessagesApi : ScatterGatherApiMessageView<IErrorMessageDataStore, ScatterGatherApiMessageViewWithSystemMessagesContext>
{
public GetAllMessagesApi(IErrorMessageDataStore dataStore, Settings settings, IHttpClientFactory httpClientFactory, IHttpContextAccessor httpContextAccessor) : base(dataStore, settings, httpClientFactory, httpContextAccessor)
public GetAllMessagesApi(IErrorMessageDataStore dataStore, Settings settings,
IHttpClientFactory httpClientFactory) : base(dataStore, settings, httpClientFactory)
{
}

Expand Down
Expand Up @@ -3,10 +3,9 @@ namespace ServiceControl.CompositeViews.Messages
using System.Collections.Generic;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Persistence;
using Persistence.Infrastructure;
using ServiceBus.Management.Infrastructure.Settings;
using ServiceControl.Persistence;
using ServiceControl.Persistence.Infrastructure;

public record AllMessagesForEndpointContext(
PagingInfo PagingInfo,
Expand All @@ -17,7 +16,8 @@ public record AllMessagesForEndpointContext(

public class GetAllMessagesForEndpointApi : ScatterGatherApiMessageView<IErrorMessageDataStore, AllMessagesForEndpointContext>
{
public GetAllMessagesForEndpointApi(IErrorMessageDataStore dataStore, Settings settings, IHttpClientFactory httpClientFactory, IHttpContextAccessor httpContextAccessor) : base(dataStore, settings, httpClientFactory, httpContextAccessor)
public GetAllMessagesForEndpointApi(IErrorMessageDataStore dataStore, Settings settings,
IHttpClientFactory httpClientFactory) : base(dataStore, settings, httpClientFactory)
{
}

Expand Down
Expand Up @@ -2,6 +2,8 @@
{
using System.Collections.Generic;
using System.Threading.Tasks;
using Infrastructure.WebApi;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.AspNetCore.Mvc;
using Persistence.Infrastructure;

Expand All @@ -12,8 +14,17 @@ public class GetMessagesByConversationController(MessagesByConversationApi byCon
{
[Route("conversations/{conversationId:required:minlength(1)}")]
[HttpGet]
public Task<IList<MessagesView>> Messages([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo,
[FromQuery(Name = "include_system_messages")] bool includeSystemMessages, string conversationId) =>
byConversationApi.Execute(new(pagingInfo, sortInfo, includeSystemMessages, conversationId));
public async Task<IList<MessagesView>> Messages([FromQuery] PagingInfo pagingInfo,
[FromQuery] SortInfo sortInfo,
[FromQuery(Name = "include_system_messages")]
bool includeSystemMessages, string conversationId)
{
QueryResult<IList<MessagesView>> result = await byConversationApi.Execute(
new MessagesByConversationContext(pagingInfo, sortInfo, includeSystemMessages, conversationId),
Request.GetEncodedPathAndQuery());

Response.WithQueryStatsAndPagingInfo(result.QueryStats, pagingInfo);
return result.Results;
}
}
}
90 changes: 72 additions & 18 deletions src/ServiceControl/CompositeViews/Messages/GetMessagesController.cs
Expand Up @@ -4,15 +4,15 @@ namespace ServiceControl.CompositeViews.Messages
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Infrastructure;
using Infrastructure.WebApi;
using MessageCounting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.AspNetCore.Mvc;
using NServiceBus.Logging;
using Operations.BodyStorage;
using Persistence.Infrastructure;
using ServiceBus.Management.Infrastructure.Settings;
using ServiceControl.CompositeViews.MessageCounting;
using ServiceControl.Operations.BodyStorage;
using Yarp.ReverseProxy.Forwarder;

// All routes matching `messages/*` must be in this controller as WebAPI cannot figure out the overlapping routes
Expand All @@ -33,21 +33,45 @@ public class GetMessagesController(
{
[Route("messages")]
[HttpGet]
public Task<IList<MessagesView>> Messages([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo,
[FromQuery(Name = "include_system_messages")] bool includeSystemMessages) => allMessagesApi.Execute(
new(pagingInfo, sortInfo, includeSystemMessages));
public async Task<IList<MessagesView>> Messages([FromQuery] PagingInfo pagingInfo,
[FromQuery] SortInfo sortInfo,
[FromQuery(Name = "include_system_messages")]
bool includeSystemMessages)
{
QueryResult<IList<MessagesView>> result = await allMessagesApi.Execute(
new ScatterGatherApiMessageViewWithSystemMessagesContext(pagingInfo, sortInfo, includeSystemMessages),
Request.GetEncodedPathAndQuery());

Response.WithQueryStatsAndPagingInfo(result.QueryStats, pagingInfo);
return result.Results;
}

[Route("endpoints/{endpoint}/messages")]
[HttpGet]
public Task<IList<MessagesView>> MessagesForEndpoint([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo,
[FromQuery(Name = "include_system_messages")] bool includeSystemMessages, string endpoint) =>
allMessagesForEndpointApi.Execute(new(pagingInfo, sortInfo, includeSystemMessages, endpoint));
public async Task<IList<MessagesView>> MessagesForEndpoint([FromQuery] PagingInfo pagingInfo,
[FromQuery] SortInfo sortInfo,
[FromQuery(Name = "include_system_messages")]
bool includeSystemMessages, string endpoint)
{
QueryResult<IList<MessagesView>> result = await allMessagesForEndpointApi.Execute(
new AllMessagesForEndpointContext(pagingInfo, sortInfo, includeSystemMessages, endpoint),
Request.GetEncodedPathAndQuery());

Response.WithQueryStatsAndPagingInfo(result.QueryStats, pagingInfo);
return result.Results;
}

// the endpoint name is needed in the route to match the route and forward it as path and query to the remotes
[Route("endpoints/{endpoint}/audit-count")]
[HttpGet]
public Task<IList<AuditCount>> GetEndpointAuditCounts([FromQuery] PagingInfo pagingInfo, string endpoint) =>
auditCountsForEndpointApi.Execute(new(pagingInfo, endpoint));
public async Task<IList<AuditCount>> GetEndpointAuditCounts([FromQuery] PagingInfo pagingInfo, string endpoint)
{
QueryResult<IList<AuditCount>> result = await auditCountsForEndpointApi.Execute(
new AuditCountsForEndpointContext(pagingInfo, endpoint), Request.GetEncodedPathAndQuery());

Response.WithQueryStatsAndPagingInfo(result.QueryStats, pagingInfo);
return result.Results;
}

[Route("messages/{id}/body")]
[HttpGet]
Expand Down Expand Up @@ -89,22 +113,52 @@ public async Task<IActionResult> Get(string id, [FromQuery(Name = "instance_id")

[Route("messages/search")]
[HttpGet]
public Task<IList<MessagesView>> Search([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo, string q) => api.Execute(new(pagingInfo, sortInfo, q));
public async Task<IList<MessagesView>> Search([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo,
string q)
{
QueryResult<IList<MessagesView>> result = await api.Execute(new SearchApiContext(pagingInfo, sortInfo, q),
Request.GetEncodedPathAndQuery());

Response.WithQueryStatsAndPagingInfo(result.QueryStats, pagingInfo);
return result.Results;
}

[Route("messages/search/{keyword}")]
[HttpGet]
public Task<IList<MessagesView>> SearchByKeyWord([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo, string keyword) =>
api.Execute(new(pagingInfo, sortInfo, keyword?.Replace("/", @"\")));
public async Task<IList<MessagesView>> SearchByKeyWord([FromQuery] PagingInfo pagingInfo,
[FromQuery] SortInfo sortInfo, string keyword)
{
QueryResult<IList<MessagesView>> result = await api.Execute(
new SearchApiContext(pagingInfo, sortInfo, keyword?.Replace("/", @"\")),
Request.GetEncodedPathAndQuery());

Response.WithQueryStatsAndPagingInfo(result.QueryStats, pagingInfo);
return result.Results;
}

[Route("endpoints/{endpoint}/messages/search")]
[HttpGet]
public Task<IList<MessagesView>> Search([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo, string endpoint, string q) =>
endpointApi.Execute(new(pagingInfo, sortInfo, endpoint, q));
public async Task<IList<MessagesView>> Search([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo,
string endpoint, string q)
{
QueryResult<IList<MessagesView>> result = await endpointApi.Execute(
new SearchEndpointContext(pagingInfo, sortInfo, endpoint, q), Request.GetEncodedPathAndQuery());

Response.WithQueryStatsAndPagingInfo(result.QueryStats, pagingInfo);
return result.Results;
}

[Route("endpoints/{endpoint}/messages/search/{keyword}")]
[HttpGet]
public Task<IList<MessagesView>> SearchByKeyword([FromQuery] PagingInfo pagingInfo, [FromQuery] SortInfo sortInfo, string endpoint, string keyword) =>
endpointApi.Execute(new(pagingInfo, sortInfo, endpoint, keyword));
public async Task<IList<MessagesView>> SearchByKeyword([FromQuery] PagingInfo pagingInfo,
[FromQuery] SortInfo sortInfo, string endpoint, string keyword)
{
QueryResult<IList<MessagesView>> result = await endpointApi.Execute(
new SearchEndpointContext(pagingInfo, sortInfo, endpoint, keyword), Request.GetEncodedPathAndQuery());

Response.WithQueryStatsAndPagingInfo(result.QueryStats, pagingInfo);
return result.Results;
}

static ILog logger = LogManager.GetLogger(typeof(GetMessagesController));
}
Expand Down
Expand Up @@ -3,10 +3,9 @@ namespace ServiceControl.CompositeViews.Messages
using System.Collections.Generic;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Persistence;
using Persistence.Infrastructure;
using ServiceBus.Management.Infrastructure.Settings;
using ServiceControl.Persistence;
using ServiceControl.Persistence.Infrastructure;

public record MessagesByConversationContext(
PagingInfo PagingInfo,
Expand All @@ -18,8 +17,8 @@ public record MessagesByConversationContext(
public class MessagesByConversationApi : ScatterGatherApiMessageView<IErrorMessageDataStore, MessagesByConversationContext>
{
public MessagesByConversationApi(IErrorMessageDataStore dataStore, Settings settings,
IHttpClientFactory httpClientFactory, IHttpContextAccessor httpContextAccessor) : base(dataStore, settings,
httpClientFactory, httpContextAccessor)
IHttpClientFactory httpClientFactory) : base(dataStore, settings,
httpClientFactory)
{
}

Expand Down
27 changes: 10 additions & 17 deletions src/ServiceControl/CompositeViews/Messages/ScatterGatherApi.cs
Expand Up @@ -5,14 +5,11 @@ namespace ServiceControl.CompositeViews.Messages
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text.Json;
using System.Threading.Tasks;
using Infrastructure.WebApi;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using NServiceBus.Logging;
using Persistence.Infrastructure;
using ServiceBus.Management.Infrastructure.Settings;
using ServiceControl.Persistence.Infrastructure;
using JsonSerializer = System.Text.Json.JsonSerializer;

interface IApi
Expand All @@ -30,23 +27,21 @@ public abstract class ScatterGatherApi<TDataStore, TIn, TOut> : ScatterGatherApi
where TIn : ScatterGatherContext
where TOut : class
{
protected ScatterGatherApi(TDataStore store, Settings settings, IHttpClientFactory httpClientFactory, IHttpContextAccessor httpContextAccessor)
protected ScatterGatherApi(TDataStore store, Settings settings, IHttpClientFactory httpClientFactory)
{
this.httpContextAccessor = httpContextAccessor;
DataStore = store;
Settings = settings;
HttpClientFactory = httpClientFactory;
Logger = LogManager.GetLogger(GetType());
logger = LogManager.GetLogger(GetType());
}

protected TDataStore DataStore { get; }
Settings Settings { get; }
IHttpClientFactory HttpClientFactory { get; }

public async Task<TOut> Execute(TIn input)
public async Task<QueryResult<TOut>> Execute(TIn input, string pathAndQuery)
{
var remotes = Settings.RemoteInstances;
var pathAndQuery = httpContextAccessor.HttpContext!.Request.GetEncodedPathAndQuery();
var instanceId = Settings.InstanceId;
var tasks = new List<Task<QueryResult<TOut>>>(remotes.Length + 1)
{
Expand All @@ -65,9 +60,7 @@ public async Task<TOut> Execute(TIn input)
var results = await Task.WhenAll(tasks);
var response = AggregateResults(input, results);

httpContextAccessor.HttpContext!.Response.WithQueryStatsAndPagingInfo(response.QueryStats, input.PagingInfo);

return response.Results;
return response;
}

async Task<QueryResult<TOut>> LocalCall(TIn input, string instanceId)
Expand Down Expand Up @@ -127,18 +120,19 @@ async Task<QueryResult<TOut>> FetchAndParse(HttpClient httpClient, string pathAn
catch (HttpRequestException httpRequestException)
{
remoteInstanceSetting.TemporarilyUnavailable = true;
Logger.Warn($"An HttpRequestException occurred when querying remote instance at {remoteInstanceSetting.BaseAddress}. The instance at uri: {remoteInstanceSetting.BaseAddress} will be temporarily disabled.",
logger.Warn(
$"An HttpRequestException occurred when querying remote instance at {remoteInstanceSetting.BaseAddress}. The instance at uri: {remoteInstanceSetting.BaseAddress} will be temporarily disabled.",
httpRequestException);
return QueryResult<TOut>.Empty();
}
catch (OperationCanceledException)
{
Logger.Warn($"Failed to query remote instance at {remoteInstanceSetting.BaseAddress} due to a timeout");
logger.Warn($"Failed to query remote instance at {remoteInstanceSetting.BaseAddress} due to a timeout");
return QueryResult<TOut>.Empty();
}
catch (Exception exception)
{
Logger.Warn($"Failed to query remote instance at {remoteInstanceSetting.BaseAddress}.", exception);
logger.Warn($"Failed to query remote instance at {remoteInstanceSetting.BaseAddress}.", exception);
return QueryResult<TOut>.Empty();
}
}
Expand Down Expand Up @@ -168,7 +162,6 @@ static async Task<QueryResult<TOut>> ParseResult(HttpResponseMessage responseMes
return new QueryResult<TOut>(remoteResults, new QueryStatsInfo(etag, totalCount, isStale: false));
}

readonly ILog Logger;
readonly IHttpContextAccessor httpContextAccessor;
readonly ILog logger;
}
}
Expand Up @@ -3,9 +3,8 @@ namespace ServiceControl.CompositeViews.Messages
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using Microsoft.AspNetCore.Http;
using Persistence.Infrastructure;
using ServiceBus.Management.Infrastructure.Settings;
using ServiceControl.Persistence.Infrastructure;

public record ScatterGatherApiMessageViewWithSystemMessagesContext(
PagingInfo PagingInfo,
Expand All @@ -17,7 +16,8 @@ public record ScatterGatherApiMessageViewContext(PagingInfo PagingInfo, SortInfo
public abstract class ScatterGatherApiMessageView<TDataStore, TInput> : ScatterGatherApi<TDataStore, TInput, IList<MessagesView>>
where TInput : ScatterGatherApiMessageViewContext
{
protected ScatterGatherApiMessageView(TDataStore dataStore, Settings settings, IHttpClientFactory httpClientFactory, IHttpContextAccessor httpContextAccessor) : base(dataStore, settings, httpClientFactory, httpContextAccessor)
protected ScatterGatherApiMessageView(TDataStore dataStore, Settings settings,
IHttpClientFactory httpClientFactory) : base(dataStore, settings, httpClientFactory)
{
}

Expand Down

0 comments on commit e870991

Please sign in to comment.