Skip to content

Commit

Permalink
Merge pull request #73255 from CyrusNajmabadi/cleanup
Browse files Browse the repository at this point in the history
Restore parallel processing options for Find-References.
  • Loading branch information
CyrusNajmabadi committed Apr 29, 2024
2 parents 75995e2 + b0ccdd2 commit fa64d90
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,13 @@ namespace Microsoft.CodeAnalysis.FindSymbols;

internal partial class FindReferencesSearchEngine
{
private static readonly ObjectPool<MetadataUnifyingSymbolHashSet> s_metadataUnifyingSymbolHashSetPool = new(() => []);

private readonly Solution _solution;
private readonly IImmutableSet<Document>? _documents;
private readonly ImmutableArray<IReferenceFinder> _finders;
private readonly IStreamingProgressTracker _progressTracker;
private readonly IStreamingFindReferencesProgress _progress;
private readonly FindReferencesSearchOptions _options;

/// <summary>
/// Scheduler to run our tasks on. If we're in <see cref="FindReferencesSearchOptions.Explicit"/> mode, we'll
/// run all our tasks concurrently. Otherwise, we will run them serially using <see cref="s_exclusiveScheduler"/>
/// </summary>
private readonly TaskScheduler _scheduler;
private static readonly TaskScheduler s_exclusiveScheduler = new ConcurrentExclusiveSchedulerPair().ExclusiveScheduler;

/// <summary>
Expand All @@ -60,14 +53,23 @@ internal partial class FindReferencesSearchEngine
_options = options;

_progressTracker = progress.ProgressTracker;

// If we're an explicit invocation, just defer to the threadpool to execute all our work in parallel to get
// things done as quickly as possible. If we're running implicitly, then use a
// ConcurrentExclusiveSchedulerPair's exclusive scheduler, as that's the most built-in way in the TPL to
// run things serially.
_scheduler = _options.Explicit ? TaskScheduler.Default : s_exclusiveScheduler;
}

/// <summary>
/// Builds the <see cref="ParallelOptions"/> that control the parallelism of the search. In <see
/// cref="FindReferencesSearchOptions.Explicit"/> mode the work is scheduled on <see
/// cref="TaskScheduler.Default"/> so it runs concurrently; otherwise it is scheduled on <see
/// cref="s_exclusiveScheduler"/> so it runs serially.
/// </summary>
/// <param name="cancellationToken">Token stored on the returned options.</param>
private ParallelOptions GetParallelOptions(CancellationToken cancellationToken)
{
    // An explicit invocation defers to the threadpool so all the work executes in parallel and finishes as
    // quickly as possible. An implicit invocation uses an exclusive scheduler, as that's the most built-in
    // way in the TPL to run things serially.
    var scheduler = _options.Explicit
        ? TaskScheduler.Default
        : s_exclusiveScheduler;

    return new ParallelOptions
    {
        CancellationToken = cancellationToken,
        TaskScheduler = scheduler,
    };
}

/// <summary>
/// Convenience overload for a single symbol: wraps <paramref name="symbol"/> in a one-element collection
/// and forwards to the overload that accepts a collection of symbols.
/// </summary>
public Task FindReferencesAsync(ISymbol symbol, CancellationToken cancellationToken)
{
    return FindReferencesAsync([symbol], cancellationToken);
}

Expand Down Expand Up @@ -112,17 +114,12 @@ public Task FindReferencesAsync(ISymbol symbol, CancellationToken cancellationTo
// set of documents to search, we only bother with those.
var projectsToSearch = await GetProjectsToSearchAsync(allSymbols, cancellationToken).ConfigureAwait(false);

// We need to process projects in order when updating our symbol set. Say we have three projects (A, B
// and C), we cannot necessarily find inherited symbols in C until we have searched B. Importantly,
// while we're processing each project linearly to update the symbol set we're searching for, we still
// then process the projects in parallel once we know the set of symbols we're searching for in that
// project.
await _progressTracker.AddItemsAsync(projectsToSearch.Length, cancellationToken).ConfigureAwait(false);

// Pull off and start searching each project as soon as we can once we've done the inheritance cascade into it.
await RoslynParallel.ForEachAsync(
GetProjectsAndSymbolsToSearchAsync(symbolSet, projectsToSearch, cancellationToken),
cancellationToken,
GetParallelOptions(cancellationToken),
async (tuple, cancellationToken) => await ProcessProjectAsync(
tuple.project, tuple.allSymbols, onReferenceFound, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false);
}
Expand All @@ -132,6 +129,11 @@ public Task FindReferencesAsync(ISymbol symbol, CancellationToken cancellationTo
ImmutableArray<Project> projectsToSearch,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
// We need to process projects in order when updating our symbol set. Say we have three projects (A, B
// and C), we cannot necessarily find inherited symbols in C until we have searched B. Importantly,
// while we're processing each project linearly to update the symbol set we're searching for, we still
// then process the projects in parallel once we know the set of symbols we're searching for in that
// project.
var dependencyGraph = _solution.GetProjectDependencyGraph();
foreach (var projectId in dependencyGraph.GetTopologicallySortedProjects(cancellationToken))
{
Expand All @@ -153,9 +155,6 @@ public Task FindReferencesAsync(ISymbol symbol, CancellationToken cancellationTo
}
}

public Task CreateWorkAsync(Func<Task> createWorkAsync, CancellationToken cancellationToken)
=> Task.Factory.StartNew(createWorkAsync, cancellationToken, TaskCreationOptions.None, _scheduler).Unwrap();

/// <summary>
/// Notify the caller of the engine about the definitions we've found that we're looking for. We'll only notify
/// them once per symbol group, but we may have to notify about new symbols each time we expand our symbol set
Expand Down Expand Up @@ -231,44 +230,30 @@ private async ValueTask<SymbolGroup> ReportGroupAsync(ISymbol symbol, Cancellati
_options, cancellationToken).ConfigureAwait(false);

foreach (var document in foundDocuments)
{
var docSymbols = GetSymbolSet(documentToSymbols, document);
docSymbols.Add(symbol);
}
GetSymbolSet(documentToSymbols, document).Add(symbol);

foundDocuments.Clear();
}
}

await RoslynParallel.ForEachAsync(
documentToSymbols,
cancellationToken,
GetParallelOptions(cancellationToken),
(kvp, cancellationToken) =>
ProcessDocumentAsync(kvp.Key, kvp.Value, symbolToGlobalAliases, onReferenceFound, cancellationToken)).ConfigureAwait(false);
}
finally
{
foreach (var (_, symbols) in documentToSymbols)
{
symbols.Clear();
s_metadataUnifyingSymbolHashSetPool.Free(symbols);
}
MetadataUnifyingSymbolHashSet.ClearAndFree(symbols);

FreeGlobalAliases(symbolToGlobalAliases);

await _progressTracker.ItemCompletedAsync(cancellationToken).ConfigureAwait(false);
}

static MetadataUnifyingSymbolHashSet GetSymbolSet<T>(PooledDictionary<T, MetadataUnifyingSymbolHashSet> dictionary, T key) where T : notnull
{
if (!dictionary.TryGetValue(key, out var set))
{
set = s_metadataUnifyingSymbolHashSetPool.Allocate();
dictionary.Add(key, set);
}

return set;
}
=> dictionary.GetOrAdd(key, static _ => MetadataUnifyingSymbolHashSet.AllocateFromPool());
}

private static PooledHashSet<U>? TryGet<T, U>(Dictionary<T, PooledHashSet<U>> dictionary, T key) where T : notnull
Expand Down Expand Up @@ -304,13 +289,13 @@ private async ValueTask<SymbolGroup> ReportGroupAsync(ISymbol symbol, Cancellati

await RoslynParallel.ForEachAsync(
symbols,
cancellationToken,
GetParallelOptions(cancellationToken),
async (symbol, cancellationToken) =>
{
// symbolToGlobalAliases is safe to read in parallel. It is created fully before this point and is no
// longer mutated.
var globalAliases = TryGet(symbolToGlobalAliases, symbol);
var state = new FindReferencesDocumentState(cache, globalAliases);
var state = new FindReferencesDocumentState(
cache, TryGet(symbolToGlobalAliases, symbol));
await ProcessDocumentAsync(symbol, state, onReferenceFound).ConfigureAwait(false);
}).ConfigureAwait(false);
Expand Down Expand Up @@ -357,24 +342,13 @@ private async ValueTask<SymbolGroup> ReportGroupAsync(ISymbol symbol, Cancellati
var aliases = await finder.DetermineGlobalAliasesAsync(
symbol, project, cancellationToken).ConfigureAwait(false);
if (aliases.Length > 0)
{
var globalAliases = GetGlobalAliasesSet(symbolToGlobalAliases, symbol);
globalAliases.AddRange(aliases);
}
GetGlobalAliasesSet(symbolToGlobalAliases, symbol).AddRange(aliases);
}
}
}

private static PooledHashSet<string> GetGlobalAliasesSet<T>(PooledDictionary<T, PooledHashSet<string>> dictionary, T key) where T : notnull
{
if (!dictionary.TryGetValue(key, out var set))
{
set = PooledHashSet<string>.GetInstance();
dictionary.Add(key, set);
}

return set;
}
=> dictionary.GetOrAdd(key, static _ => PooledHashSet<string>.GetInstance());

private static void FreeGlobalAliases(PooledDictionary<ISymbol, PooledHashSet<string>> symbolToGlobalAliases)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ async ValueTask PerformSearchInProjectAsync(ImmutableArray<ISymbol> symbols, Pro

foreach (var symbol in symbols)
{
var globalAliases = TryGet(symbolToGlobalAliases, symbol);
var state = new FindReferencesDocumentState(cache, globalAliases);
var state = new FindReferencesDocumentState(
cache, TryGet(symbolToGlobalAliases, symbol));

await PerformSearchInDocumentWorkerAsync(symbol, state).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,24 @@
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.CodeAnalysis.PooledObjects;

namespace Microsoft.CodeAnalysis.FindSymbols;

/// <summary>
/// A <see cref="HashSet{T}"/> of <see cref="ISymbol"/> whose elements are compared with <see
/// cref="MetadataUnifyingEquivalenceComparer.Instance"/>. Instances can be taken from, and handed back to, a
/// shared pool through <see cref="AllocateFromPool"/> and <see cref="ClearAndFree"/>.
/// </summary>
internal sealed class MetadataUnifyingSymbolHashSet : HashSet<ISymbol>
{
    /// <summary>
    /// Pool backing <see cref="AllocateFromPool"/> and <see cref="ClearAndFree"/>.
    /// </summary>
    private static readonly ObjectPool<MetadataUnifyingSymbolHashSet> s_pool =
        new(() => new MetadataUnifyingSymbolHashSet());

    /// <summary>
    /// Creates an empty set that uses <see cref="MetadataUnifyingEquivalenceComparer.Instance"/>.
    /// </summary>
    public MetadataUnifyingSymbolHashSet()
        : base(MetadataUnifyingEquivalenceComparer.Instance)
    {
    }

    /// <summary>
    /// Obtains a set from the shared pool. Return it with <see cref="ClearAndFree"/> when finished.
    /// </summary>
    public static MetadataUnifyingSymbolHashSet AllocateFromPool()
    {
        return s_pool.Allocate();
    }

    /// <summary>
    /// Empties <paramref name="set"/> and then returns it to the shared pool.
    /// </summary>
    public static void ClearAndFree(MetadataUnifyingSymbolHashSet set)
    {
        // Clear before freeing so the pooled instance carries no symbols over to its next use.
        set.Clear();
        s_pool.Free(set);
    }
}
54 changes: 38 additions & 16 deletions src/Workspaces/Core/Portable/Shared/Utilities/RoslynParallel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,59 +10,81 @@

namespace Microsoft.CodeAnalysis.Shared.Utilities;

#pragma warning disable CA1068 // CancellationToken parameters must come last

internal static class RoslynParallel
{
#pragma warning disable CA1068 // CancellationToken parameters must come last
public static async Task ForEachAsync<TSource>(
#pragma warning restore CA1068 // CancellationToken parameters must come last
/// <summary>
/// Invokes <paramref name="body"/> for each item of <paramref name="source"/> by forwarding to the
/// <see cref="ParallelOptions"/>-based overload, using options built from
/// <paramref name="cancellationToken"/>.
/// </summary>
public static Task ForEachAsync<TSource>(
    IEnumerable<TSource> source,
    CancellationToken cancellationToken,
    Func<TSource, CancellationToken, ValueTask> body)
    => ForEachAsync(source, GetParallelOptions(cancellationToken), body);

/// <summary>
/// Creates <see cref="ParallelOptions"/> that schedule on <see cref="TaskScheduler.Default"/> and carry
/// <paramref name="cancellationToken"/>.
/// </summary>
private static ParallelOptions GetParallelOptions(CancellationToken cancellationToken)
{
    var parallelOptions = new ParallelOptions
    {
        TaskScheduler = TaskScheduler.Default,
        CancellationToken = cancellationToken,
    };

    return parallelOptions;
}

public static async Task ForEachAsync<TSource>(
IEnumerable<TSource> source,
ParallelOptions parallelOptions,
Func<TSource, CancellationToken, ValueTask> body)
{
var cancellationToken = parallelOptions.CancellationToken;
if (cancellationToken.IsCancellationRequested)
return;

#if NET
await Parallel.ForEachAsync(source, cancellationToken, body).ConfigureAwait(false);
await Parallel.ForEachAsync(source, parallelOptions, body).ConfigureAwait(false);
#else
using var _ = ArrayBuilder<Task>.GetInstance(out var tasks);

foreach (var item in source)
{
tasks.Add(Task.Run(async () =>
{
await body(item, cancellationToken).ConfigureAwait(false);
}, cancellationToken));
tasks.Add(CreateWorkAsync(
parallelOptions.TaskScheduler,
async () => await body(item, cancellationToken).ConfigureAwait(false),
cancellationToken));
}

await Task.WhenAll(tasks).ConfigureAwait(false);
#endif
}

#pragma warning disable CA1068 // CancellationToken parameters must come last
public static async Task ForEachAsync<TSource>(
#pragma warning restore CA1068 // CancellationToken parameters must come last
/// <summary>
/// Invokes <paramref name="body"/> for each item produced by the asynchronous <paramref name="source"/> by
/// forwarding to the <see cref="ParallelOptions"/>-based overload, using options built from
/// <paramref name="cancellationToken"/>.
/// </summary>
public static Task ForEachAsync<TSource>(
    IAsyncEnumerable<TSource> source,
    CancellationToken cancellationToken,
    Func<TSource, CancellationToken, ValueTask> body)
    => ForEachAsync(source, GetParallelOptions(cancellationToken), body);

public static async Task ForEachAsync<TSource>(
IAsyncEnumerable<TSource> source,
ParallelOptions parallelOptions,
Func<TSource, CancellationToken, ValueTask> body)
{
var cancellationToken = parallelOptions.CancellationToken;
if (cancellationToken.IsCancellationRequested)
return;

#if NET
await Parallel.ForEachAsync(source, cancellationToken, body).ConfigureAwait(false);
await Parallel.ForEachAsync(source, parallelOptions, body).ConfigureAwait(false);
#else
using var _ = ArrayBuilder<Task>.GetInstance(out var tasks);

await foreach (var item in source)
{
tasks.Add(Task.Run(async () =>
{
await body(item, cancellationToken).ConfigureAwait(false);
}, cancellationToken));
tasks.Add(CreateWorkAsync(
parallelOptions.TaskScheduler,
async () => await body(item, cancellationToken).ConfigureAwait(false),
cancellationToken));
}

await Task.WhenAll(tasks).ConfigureAwait(false);
#endif
}

/// <summary>
/// Starts <paramref name="createWorkAsync"/> on <paramref name="scheduler"/> and returns a task that
/// represents the asynchronous work the delegate produces.
/// </summary>
/// <param name="scheduler">Scheduler the delegate is started on.</param>
/// <param name="createWorkAsync">Delegate that produces the asynchronous work.</param>
/// <param name="cancellationToken">Token passed to the task factory when starting the delegate.</param>
public static Task CreateWorkAsync(TaskScheduler scheduler, Func<Task> createWorkAsync, CancellationToken cancellationToken)
{
    // StartNew with a Func<Task> yields a Task<Task>; Unwrap turns that into a task for the inner work.
    var outerTask = Task.Factory.StartNew(createWorkAsync, cancellationToken, TaskCreationOptions.None, scheduler);
    return outerTask.Unwrap();
}
}

0 comments on commit fa64d90

Please sign in to comment.