Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SemanticKernel: Correcting non-standard way of working with PromptExecutionSettings #689

Merged
merged 14 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/SemanticKernelPrompt.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using LLama.Common;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using LLamaSharp.SemanticKernel.TextCompletion;
using Microsoft.SemanticKernel.TextGeneration;
using Microsoft.Extensions.DependencyInjection;
using LLamaSharp.SemanticKernel;

namespace LLama.Examples.Examples
{
Expand Down Expand Up @@ -31,7 +31,7 @@ public static async Task Run()

One line TLDR with the fewest words.";

ChatRequestSettings settings = new() { MaxTokens = 100 };
LLamaSharpPromptExecutionSettings settings = new() { MaxTokens = 100 };
var summarize = kernel.CreateFunctionFromPrompt(prompt, settings);

string text1 = @"
Expand Down
12 changes: 6 additions & 6 deletions LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
public sealed class LLamaSharpChatCompletion : IChatCompletionService
{
private readonly ILLamaExecutor _model;
private ChatRequestSettings defaultRequestSettings;
private LLamaSharpPromptExecutionSettings defaultRequestSettings;
private readonly IHistoryTransform historyTransform;
private readonly ITextStreamTransform outputTransform;

private readonly Dictionary<string, object?> _attributes = new();

public IReadOnlyDictionary<string, object?> Attributes => this._attributes;

static ChatRequestSettings GetDefaultSettings()
static LLamaSharpPromptExecutionSettings GetDefaultSettings()
{
return new ChatRequestSettings
return new LLamaSharpPromptExecutionSettings
{
MaxTokens = 256,
Temperature = 0,
Expand All @@ -37,7 +37,7 @@ static ChatRequestSettings GetDefaultSettings()
}

public LLamaSharpChatCompletion(ILLamaExecutor model,
ChatRequestSettings? defaultRequestSettings = default,
LLamaSharpPromptExecutionSettings? defaultRequestSettings = default,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null)
{
Expand Down Expand Up @@ -65,7 +65,7 @@ public ChatHistory CreateNewChat(string? instructions = "")
public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
var settings = executionSettings != null
? ChatRequestSettings.FromRequestSettings(executionSettings)
? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());

Expand All @@ -86,7 +86,7 @@ await foreach (var token in output)
public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var settings = executionSettings != null
? ChatRequestSettings.FromRequestSettings(executionSettings)
? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());

Expand Down
7 changes: 3 additions & 4 deletions LLama.SemanticKernel/ExtensionMethods.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.ChatCompletion;
namespace LLamaSharp.SemanticKernel;

public static class ExtensionMethods
Expand All @@ -23,11 +22,11 @@ public static class ExtensionMethods
}

/// <summary>
/// Convert ChatRequestSettings to LLamaSharp InferenceParams
/// Convert LLamaSharpPromptExecutionSettings to LLamaSharp InferenceParams
/// </summary>
/// <param name="requestSettings"></param>
/// <returns></returns>
internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this ChatRequestSettings requestSettings)
internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings)
{
if (requestSettings is null)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
using Microsoft.SemanticKernel;

/* Unmerged change from project 'LLamaSharp.SemanticKernel (netstandard2.0)'
Before:
using Microsoft.SemanticKernel;
After:
using LLamaSharp;
using LLamaSharp.SemanticKernel;
using LLamaSharp.SemanticKernel;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
*/
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel.ChatCompletion;
namespace LLamaSharp.SemanticKernel;

public class ChatRequestSettings : PromptExecutionSettings
public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings
{
AsakusaRinne marked this conversation as resolved.
Show resolved Hide resolved
/// <summary>
/// Temperature controls the randomness of the completion.
Expand Down Expand Up @@ -62,36 +74,42 @@ public class ChatRequestSettings : PromptExecutionSettings
[JsonPropertyName("token_selection_biases")]
public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();

/// <summary>
/// Indicates the format of the response which can be used downstream to post-process the messages. Handlebars: handlebars_object. JSON: json_object, etc.
/// </summary>
[JsonPropertyName("response_format")]
public string ResponseFormat { get; set; } = string.Empty;

/// <summary>
/// Create a new settings object with the values from another settings object.
/// </summary>
/// <param name="requestSettings">Template configuration</param>
/// <param name="defaultMaxTokens">Default max tokens</param>
/// <returns>An instance of <see cref="LLamaSharpPromptExecutionSettings"/></returns>
public static ChatRequestSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
{
if (requestSettings is null)
{
return new ChatRequestSettings()
return new LLamaSharpPromptExecutionSettings()
{
MaxTokens = defaultMaxTokens
};
}

if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
if (requestSettings is LLamaSharpPromptExecutionSettings requestSettingsChatRequestSettings)
{
return requestSettingsChatRequestSettings;
}

var json = JsonSerializer.Serialize(requestSettings);
var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);
var chatRequestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, s_options);

if (chatRequestSettings is not null)
{
return chatRequestSettings;
}

throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings));
}

private static readonly JsonSerializerOptions s_options = CreateOptions();
Expand All @@ -105,7 +123,7 @@ private static JsonSerializerOptions CreateOptions()
AllowTrailingCommas = true,
PropertyNameCaseInsensitive = true,
ReadCommentHandling = JsonCommentHandling.Skip,
Converters = { new ChatRequestSettingsConverter() }
Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
};

return options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel.ChatCompletion;
namespace LLamaSharp.SemanticKernel;

/// <summary>
/// JSON converter for <see cref="LLamaSharpPromptExecutionSettings"/>
/// </summary>
public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
public class LLamaSharpPromptExecutionSettingsConverter : JsonConverter<LLamaSharpPromptExecutionSettings>
{
/// <inheritdoc/>
public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
public override LLamaSharpPromptExecutionSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var requestSettings = new ChatRequestSettings();
var requestSettings = new LLamaSharpPromptExecutionSettings();

while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
{
Expand Down Expand Up @@ -77,7 +77,7 @@ public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
}

/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options)
public override void Write(Utf8JsonWriter writer, LLamaSharpPromptExecutionSettings value, JsonSerializerOptions options)
{
writer.WriteStartObject();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using LLama.Abstractions;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TextGeneration;
Expand All @@ -24,7 +23,7 @@ public LLamaSharpTextCompletion(ILLamaExecutor executor)
/// <inheritdoc/>
public async Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
var settings = ChatRequestSettings.FromRequestSettings(executionSettings);
var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
var sb = new StringBuilder();
await foreach (var token in result)
Expand All @@ -37,7 +36,7 @@ await foreach (var token in result)
/// <inheritdoc/>
public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var settings = ChatRequestSettings.FromRequestSettings(executionSettings);
var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
await foreach (var token in result)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LLamaSharp.SemanticKernel.ChatCompletion;
using LLamaSharp.SemanticKernel;
using LLamaSharp.SemanticKernel.ChatCompletion;
using System.Text.Json;

namespace LLama.Unittest.SemanticKernel
Expand All @@ -10,11 +11,11 @@ public void ChatRequestSettingsConverter_DeserializeWithDefaults()
{
// Arrange
var options = new JsonSerializerOptions();
options.Converters.Add(new ChatRequestSettingsConverter());
options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
var json = "{}";

// Act
var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);

// Assert
Assert.NotNull(requestSettings);
Expand All @@ -36,7 +37,7 @@ public void ChatRequestSettingsConverter_DeserializeWithSnakeCase()
// Arrange
var options = new JsonSerializerOptions();
options.AllowTrailingCommas = true;
options.Converters.Add(new ChatRequestSettingsConverter());
options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
var json = @"{
""frequency_penalty"": 0.5,
""max_tokens"": 250,
Expand All @@ -49,7 +50,7 @@ public void ChatRequestSettingsConverter_DeserializeWithSnakeCase()
}";

// Act
var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);

// Assert
Assert.NotNull(requestSettings);
Expand All @@ -73,7 +74,7 @@ public void ChatRequestSettingsConverter_DeserializeWithPascalCase()
// Arrange
var options = new JsonSerializerOptions();
options.AllowTrailingCommas = true;
options.Converters.Add(new ChatRequestSettingsConverter());
options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());
var json = @"{
""FrequencyPenalty"": 0.5,
""MaxTokens"": 250,
Expand All @@ -86,7 +87,7 @@ public void ChatRequestSettingsConverter_DeserializeWithPascalCase()
}";

// Act
var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
var requestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);

// Assert
Assert.NotNull(requestSettings);
Expand Down
16 changes: 8 additions & 8 deletions LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LLamaSharp.SemanticKernel.ChatCompletion;
using LLamaSharp.SemanticKernel;
using Microsoft.SemanticKernel;

namespace LLama.Unittest.SemanticKernel
Expand All @@ -10,7 +10,7 @@ public void ChatRequestSettings_FromRequestSettingsNull()
{
// Arrange
// Act
var requestSettings = ChatRequestSettings.FromRequestSettings(null, null);
var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, null);

// Assert
Assert.NotNull(requestSettings);
Expand All @@ -31,7 +31,7 @@ public void ChatRequestSettings_FromRequestSettingsNullWithMaxTokens()
{
// Arrange
// Act
var requestSettings = ChatRequestSettings.FromRequestSettings(null, 200);
var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, 200);

// Assert
Assert.NotNull(requestSettings);
Expand All @@ -51,7 +51,7 @@ public void ChatRequestSettings_FromRequestSettingsNullWithMaxTokens()
public void ChatRequestSettings_FromExistingRequestSettings()
{
// Arrange
var originalRequestSettings = new ChatRequestSettings()
var originalRequestSettings = new LLamaSharpPromptExecutionSettings()
{
FrequencyPenalty = 0.5,
MaxTokens = 100,
Expand All @@ -64,7 +64,7 @@ public void ChatRequestSettings_FromExistingRequestSettings()
};

// Act
var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);

// Assert
Assert.NotNull(requestSettings);
Expand All @@ -81,7 +81,7 @@ public void ChatRequestSettings_FromAIRequestSettings()
};

// Act
var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);

// Assert
Assert.NotNull(requestSettings);
Expand Down Expand Up @@ -109,7 +109,7 @@ public void ChatRequestSettings_FromAIRequestSettingsWithExtraPropertiesInSnakeC
};

// Act
var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);

// Assert
Assert.NotNull(requestSettings);
Expand Down Expand Up @@ -148,7 +148,7 @@ public void ChatRequestSettings_FromAIRequestSettingsWithExtraPropertiesInPascal
};

// Act
var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(originalRequestSettings);

// Assert
Assert.NotNull(requestSettings);
Expand Down
2 changes: 1 addition & 1 deletion LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public void ToLLamaSharpChatHistory_StateUnderTest_ExpectedBehavior()
public void ToLLamaSharpInferenceParams_StateUnderTest_ExpectedBehavior()
{
// Arrange
var requestSettings = new ChatRequestSettings();
var requestSettings = new LLamaSharpPromptExecutionSettings();

// Act
var result = ExtensionMethods.ToLLamaSharpInferenceParams(
Expand Down