Skip to content

Commit

Permalink
add missing QueryUnbufferedAsync<T> API (#1912)
Browse files Browse the repository at this point in the history
* impl QueryUnbufferedAsync<T>
* implement GridReader.ReadUnbufferedAsync<T>
  • Loading branch information
mgravell committed Jun 9, 2023
1 parent 194a0ce commit d56340b
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 31 deletions.
75 changes: 75 additions & 0 deletions Dapper/SqlMapper.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Data.Common;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -1217,5 +1218,79 @@ private static async Task<T> ExecuteScalarImplAsync<T>(IDbConnection cnn, Comman
}
return Parse<T>(result);
}

#if NET5_0_OR_GREATER
/// <summary>
/// Execute a query asynchronously using <see cref="IAsyncEnumerable{T}"/>.
/// </summary>
/// <typeparam name="T">The type of results to return.</typeparam>
/// <param name="cnn">The connection to query on.</param>
/// <param name="sql">The SQL to execute for the query.</param>
/// <param name="param">The parameters to pass, if any.</param>
/// <param name="transaction">The transaction to use, if any.</param>
/// <param name="commandTimeout">The command timeout (in seconds).</param>
/// <param name="commandType">The type of command to execute.</param>
/// <returns>
/// A sequence of data of <typeparamref name="T"/>; if a basic type (int, string, etc) is queried then the data from the first column is assumed, otherwise an instance is
/// created per row, and a direct column-name===member-name mapping is assumed (case insensitive).
/// </returns>
public static IAsyncEnumerable<T> QueryUnbufferedAsync<T>(this DbConnection cnn, string sql, object param = null, DbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null)
{
// note: in many cases of adding a new async method I might add a CancellationToken - however, cancellation is expressed via WithCancellation on iterators
return QueryUnbufferedAsync<T>(cnn, typeof(T), new CommandDefinition(sql, param, transaction, commandTimeout, commandType, CommandFlags.None, default));
}

private static IAsyncEnumerable<T> QueryUnbufferedAsync<T>(this IDbConnection cnn, Type effectiveType, CommandDefinition command)
{
return Impl(cnn, effectiveType, command, command.CancellationToken); // proxy to allow CT expression

static async IAsyncEnumerable<T> Impl(IDbConnection cnn, Type effectiveType, CommandDefinition command,
[EnumeratorCancellation] CancellationToken cancel)
{
object param = command.Parameters;
var identity = new Identity(command.CommandText, command.CommandType, cnn, effectiveType, param?.GetType());
var info = GetCacheInfo(identity, param, command.AddToCache);
bool wasClosed = cnn.State == ConnectionState.Closed;
using var cmd = command.TrySetupAsyncCommand(cnn, info.ParamReader);
DbDataReader reader = null;
try
{
if (wasClosed) await cnn.TryOpenAsync(cancel).ConfigureAwait(false);
reader = await ExecuteReaderWithFlagsFallbackAsync(cmd, wasClosed, CommandBehavior.SequentialAccess | CommandBehavior.SingleResult, cancel).ConfigureAwait(false);

var tuple = info.Deserializer;
int hash = GetColumnHash(reader);
if (tuple.Func == null || tuple.Hash != hash)
{
if (reader.FieldCount == 0)
{
yield break;
}
tuple = info.Deserializer = new DeserializerState(hash, GetDeserializer(effectiveType, reader, 0, -1, false));
if (command.AddToCache) SetQueryCache(identity, info);
}

var func = tuple.Func;

var convertToType = Nullable.GetUnderlyingType(effectiveType) ?? effectiveType;
while (await reader.ReadAsync(cancel).ConfigureAwait(false))
{
object val = func(reader);
yield return GetValue<T>(reader, effectiveType, val);
}
while (await reader.NextResultAsync(cancel).ConfigureAwait(false)) { /* ignore subsequent result sets */ }
command.OnCompleted();
}
finally
{
if (reader is not null)
{
await reader.DisposeAsync();
}
if (wasClosed) cnn.Close();
}
}
}
#endif
}
}
110 changes: 88 additions & 22 deletions Dapper/SqlMapper.GridReader.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -12,6 +12,9 @@ namespace Dapper
public static partial class SqlMapper
{
public partial class GridReader
#if NET5_0_OR_GREATER
: IAsyncDisposable
#endif
{
private readonly CancellationToken cancel;
internal GridReader(IDbCommand command, DbDataReader reader, Identity identity, DynamicParameters dynamicParams, bool addToCache, CancellationToken cancel)
Expand Down Expand Up @@ -140,7 +143,7 @@ public Task<object> ReadSingleOrDefaultAsync(Type type)

private async Task NextResultAsync()
{
if (await ((DbDataReader)reader).NextResultAsync(cancel).ConfigureAwait(false))
if (await reader.NextResultAsync(cancel).ConfigureAwait(false))
{
// readCount++;
gridIndex++;
Expand All @@ -150,14 +153,37 @@ private async Task NextResultAsync()
{
// happy path; close the reader cleanly - no
// need for "Cancel" etc
#if NET5_0_OR_GREATER
await reader.DisposeAsync();
#else
reader.Dispose();
#endif
reader = null;
callbacks?.OnCompleted();
#if NET5_0_OR_GREATER
await DisposeAsync();
#else
Dispose();
#endif
}
}

private Task<IEnumerable<T>> ReadAsyncImpl<T>(Type type, bool buffered)
{
var deserializer = ValidateAndMarkConsumed(type);
if (buffered)
{
return ReadBufferedAsync<T>(gridIndex, deserializer);
}
else
{
var result = ReadDeferred<T>(gridIndex, deserializer, type);
if (buffered) result = result?.ToList(); // for the "not a DbDataReader" scenario
return Task.FromResult(result);
}
}

private Func<DbDataReader, object> ValidateAndMarkConsumed(Type type)
{
if (reader == null) throw new ObjectDisposedException(GetType().FullName, "The reader has been disposed; this can happen after all data has been consumed");
if (IsConsumed) throw new InvalidOperationException("Query results must be consumed in the correct order, and each result can only be consumed once");
Expand All @@ -172,27 +198,10 @@ private Task<IEnumerable<T>> ReadAsyncImpl<T>(Type type, bool buffered)
cache.Deserializer = deserializer;
}
IsConsumed = true;
if (buffered && reader is DbDataReader)
{
return ReadBufferedAsync<T>(gridIndex, deserializer.Func);
}
else
{
var result = ReadDeferred<T>(gridIndex, deserializer.Func, type);
if (buffered) result = result?.ToList(); // for the "not a DbDataReader" scenario
return Task.FromResult(result);
}
}

private Task<T> ReadRowAsyncImpl<T>(Type type, Row row)
{
if (reader is DbDataReader dbReader) return ReadRowAsyncImplViaDbReader<T>(dbReader, type, row);

// no async API available; use non-async and fake it
return Task.FromResult(ReadRow<T>(type, row));
return deserializer.Func;
}

private async Task<T> ReadRowAsyncImplViaDbReader<T>(DbDataReader reader, Type type, Row row)
private async Task<T> ReadRowAsyncImpl<T>(Type type, Row row)
{
if (reader == null) throw new ObjectDisposedException(GetType().FullName, "The reader has been disposed; this can happen after all data has been consumed");
if (IsConsumed) throw new InvalidOperationException("Query results must be consumed in the correct order, and each result can only be consumed once");
Expand Down Expand Up @@ -229,7 +238,6 @@ private async Task<IEnumerable<T>> ReadBufferedAsync<T>(int index, Func<DbDataRe
{
try
{
var reader = (DbDataReader)this.reader;
var buffer = new List<T>();
while (index == gridIndex && await reader.ReadAsync(cancel).ConfigureAwait(false))
{
Expand All @@ -245,6 +253,64 @@ private async Task<IEnumerable<T>> ReadBufferedAsync<T>(int index, Func<DbDataRe
}
}
}

#if NET5_0_OR_GREATER
/// <summary>
/// Read the next grid of results.
/// </summary>
/// <typeparam name="T">The type to read.</typeparam>
public IAsyncEnumerable<T> ReadUnbufferedAsync<T>() => ReadAsyncUnbufferedImpl<T>(typeof(T));

private IAsyncEnumerable<T> ReadAsyncUnbufferedImpl<T>(Type type)
{
var deserializer = ValidateAndMarkConsumed(type);
return ReadUnbufferedAsync<T>(gridIndex, deserializer, cancel);
}

private async IAsyncEnumerable<T> ReadUnbufferedAsync<T>(int index, Func<DbDataReader, object> deserializer, [EnumeratorCancellation] CancellationToken cancel)
{
try
{
while (index == gridIndex && await reader.ReadAsync(cancel).ConfigureAwait(false))
{
yield return ConvertTo<T>(deserializer(reader));
}
}
finally // finally so that First etc progresses things even when multiple rows
{
if (index == gridIndex)
{
await NextResultAsync().ConfigureAwait(false);
}
}
}

/// <summary>
/// Dispose the grid, closing and disposing both the underlying reader and command.
/// </summary>
public async ValueTask DisposeAsync()
{
if (reader != null)
{
if (!reader.IsClosed) Command?.Cancel();
await reader.DisposeAsync();
reader = null;
}
if (Command != null)
{
if (Command is DbCommand typed)
{
await typed.DisposeAsync();
}
else
{
Command.Dispose();
}
Command = null;
}
GC.SuppressFinalize(this);
}
#endif
}
}
}
4 changes: 2 additions & 2 deletions Dapper/SqlMapper.GridReader.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Data.Common;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Data.Common;

namespace Dapper
{
Expand Down
10 changes: 7 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ Note: to get the latest pre-release build, add ` -Pre` to the end of the command

### unreleased

- add support for `SqlDecimal` and other types that need to be accessed via `DbDataReader.GetFieldValue<T>`
- add an overload of `AddTypeMap` that supports `DbDataReader.GetFieldValue<T>` for additional types
- acknowledge that in reality we only support `DbDataReader`; this has been true (via `DbConnection`) for `async` forever
- (#1910 via mgravell, fix #1907, #1263)
- add support for `SqlDecimal` and other types that need to be accessed via `DbDataReader.GetFieldValue<T>`
- add an overload of `AddTypeMap` that supports `DbDataReader.GetFieldValue<T>` for additional types
- acknowledge that in reality we only support `DbDataReader`; this has been true (via `DbConnection`) for `async` forever
- (#1912 via mgravell)
- add missing `AsyncEnumerable<T> QueryUnbufferedAsync<T>(...)` and `GridReader.ReadUnbufferedAsync<T>(...)` APIs (.NET 5 and later)
- implement `IAsyncDisposable` on `GridReader` (.NET 5 and later)

(note: new PRs will not be merged until they add release note wording here)

Expand Down
88 changes: 84 additions & 4 deletions tests/Dapper.Tests/AsyncTests.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using System.Linq;
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Diagnostics;
using System;
using System.Threading.Tasks;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using System.Data.Common;
using Xunit.Abstractions;

namespace Dapper.Tests
Expand Down Expand Up @@ -45,6 +46,85 @@ public async Task TestBasicStringUsageAsync()
Assert.Equal(new[] { "abc", "def" }, arr);
}

#if NET5_0_OR_GREATER
[Fact]
public async Task TestBasicStringUsageUnbufferedAsync()
{
var results = new List<string>();
await foreach (var value in connection.QueryUnbufferedAsync<string>("select 'abc' as [Value] union all select @txt", new { txt = "def" })
.ConfigureAwait(false))
{
results.Add(value);
}
var arr = results.ToArray();
Assert.Equal(new[] { "abc", "def" }, arr);
}

[Fact]
public async Task TestBasicStringUsageUnbufferedAsync_Cancellation()
{
using var cts = new CancellationTokenSource();
var results = new List<string>();
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
{
await foreach (var value in connection.QueryUnbufferedAsync<string>("select 'abc' as [Value] union all select @txt", new { txt = "def" })
.ConfigureAwait(false).WithCancellation(cts.Token))
{
results.Add(value);
cts.Cancel(); // cancel after first item
}
});
var arr = results.ToArray();
Assert.Equal(new[] { "abc" }, arr); // we don't expect the "def" because of the cancellation
}

[Fact]
public async Task TestBasicStringUsageViaGridReaderUnbufferedAsync()
{
var results = new List<string>();
await using (var grid = await connection.QueryMultipleAsync("select 'abc' union select 'def'; select @txt", new { txt = "ghi" })
.ConfigureAwait(false))
{
while (!grid.IsConsumed)
{
await foreach (var value in grid.ReadUnbufferedAsync<string>()
.ConfigureAwait(false))
{
results.Add(value);
}
}
}
var arr = results.ToArray();
Assert.Equal(new[] { "abc", "def", "ghi" }, arr);
}

[Fact]
public async Task TestBasicStringUsageViaGridReaderUnbufferedAsync_Cancellation()
{
using var cts = new CancellationTokenSource();
var results = new List<string>();
await using (var grid = await connection.QueryMultipleAsync("select 'abc' union select 'def'; select @txt", new { txt = "ghi" })
.ConfigureAwait(false))
{
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
{
while (!grid.IsConsumed)
{
await foreach (var value in grid.ReadUnbufferedAsync<string>()
.ConfigureAwait(false)
.WithCancellation(cts.Token))
{
results.Add(value);
}
cts.Cancel();
}
});
}
var arr = results.ToArray();
Assert.Equal(new[] { "abc", "def" }, arr); // don't expect the ghi because of cancellation
}
#endif

[Fact]
public async Task TestBasicStringUsageQueryFirstAsync()
{
Expand Down

0 comments on commit d56340b

Please sign in to comment.