Merge pull request #689 from zsogitbe/master
SemanticKernel: Correcting non-standard way of working with PromptExecutionSettings
AsakusaRinne committed Apr 30, 2024
2 parents 0c770a5 + 54c01d4 commit 6bf010d
Showing 11 changed files with 267 additions and 31 deletions.
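
The substance of the change in one before/after sketch (prompt and kernel setup as in the example file below; only the settings type is renamed):

// Before: ChatRequestSettings, now marked [Obsolete]
ChatRequestSettings settings = new() { MaxTokens = 100 };

// After: the standard Semantic Kernel-style name, used the same way
LLamaSharpPromptExecutionSettings settings = new() { MaxTokens = 100 };
var summarize = kernel.CreateFunctionFromPrompt(prompt, settings);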
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/SemanticKernelPrompt.cs
@@ -1,9 +1,9 @@
 using LLama.Common;
-using LLamaSharp.SemanticKernel.ChatCompletion;
 using Microsoft.SemanticKernel;
 using LLamaSharp.SemanticKernel.TextCompletion;
 using Microsoft.SemanticKernel.TextGeneration;
 using Microsoft.Extensions.DependencyInjection;
+using LLamaSharp.SemanticKernel;

namespace LLama.Examples.Examples
{
@@ -31,7 +31,7 @@ public static async Task Run()
One line TLDR with the fewest words.";

-ChatRequestSettings settings = new() { MaxTokens = 100 };
+LLamaSharpPromptExecutionSettings settings = new() { MaxTokens = 100 };
var summarize = kernel.CreateFunctionFromPrompt(prompt, settings);

string text1 = @"
1 change: 1 addition & 0 deletions LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
@@ -4,6 +4,7 @@

namespace LLamaSharp.SemanticKernel.ChatCompletion;

+[Obsolete("Use LLamaSharpPromptExecutionSettings instead")]
public class ChatRequestSettings : PromptExecutionSettings
{
/// <summary>
1 change: 1 addition & 0 deletions LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
@@ -8,6 +8,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
/// <summary>
/// JSON converter for <see cref="ChatRequestSettings"/>
/// </summary>
+[Obsolete("Use LLamaSharpPromptExecutionSettingsConverter instead")]
public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
{
/// <inheritdoc/>
12 changes: 6 additions & 6 deletions LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -18,7 +18,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
public sealed class LLamaSharpChatCompletion : IChatCompletionService
{
private readonly ILLamaExecutor _model;
-private ChatRequestSettings defaultRequestSettings;
+private LLamaSharpPromptExecutionSettings defaultRequestSettings;
private readonly IHistoryTransform historyTransform;
private readonly ITextStreamTransform outputTransform;

@@ -27,9 +27,9 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService

public IReadOnlyDictionary<string, object?> Attributes => this._attributes;

-static ChatRequestSettings GetDefaultSettings()
+static LLamaSharpPromptExecutionSettings GetDefaultSettings()
{
-return new ChatRequestSettings
+return new LLamaSharpPromptExecutionSettings
{
MaxTokens = 256,
Temperature = 0,
@@ -39,7 +39,7 @@ static ChatRequestSettings GetDefaultSettings()
}

public LLamaSharpChatCompletion(ILLamaExecutor model,
-ChatRequestSettings? defaultRequestSettings = default,
+LLamaSharpPromptExecutionSettings? defaultRequestSettings = default,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null)
{
@@ -68,7 +68,7 @@ public ChatHistory CreateNewChat(string? instructions = "")
public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
var settings = executionSettings != null
-? ChatRequestSettings.FromRequestSettings(executionSettings)
+? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;

string prompt = this._getFormattedPrompt(chatHistory);
@@ -89,7 +89,7 @@ await foreach (var token in output)
public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var settings = executionSettings != null
-? ChatRequestSettings.FromRequestSettings(executionSettings)
+? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;

string prompt = this._getFormattedPrompt(chatHistory);
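
For reference, a minimal sketch of constructing the service against the renamed settings type; the executor wiring is assumed and is not part of this diff:

// Assumes 'executor' is an ILLamaExecutor created elsewhere (e.g. a StatelessExecutor).
var chat = new LLamaSharpChatCompletion(
    executor,
    new LLamaSharpPromptExecutionSettings
    {
        MaxTokens = 256,  // mirrors GetDefaultSettings() above
        Temperature = 0
    });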
7 changes: 3 additions & 4 deletions LLama.SemanticKernel/ExtensionMethods.cs
@@ -1,5 +1,4 @@
-using LLamaSharp.SemanticKernel.ChatCompletion;
-using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.ChatCompletion;
namespace LLamaSharp.SemanticKernel;

public static class ExtensionMethods
@@ -23,11 +22,11 @@ public static class ExtensionMethods
}

/// <summary>
-/// Convert ChatRequestSettings to LLamaSharp InferenceParams
+/// Convert LLamaSharpPromptExecutionSettings to LLamaSharp InferenceParams
/// </summary>
/// <param name="requestSettings"></param>
/// <returns></returns>
-internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this ChatRequestSettings requestSettings)
+internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings)
{
if (requestSettings is null)
{
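
The body of ToLLamaSharpInferenceParams sits outside this diff. Presumably it copies the matching fields onto LLama.Common.InferenceParams, roughly as below; the InferenceParams property names here are assumptions, not taken from this commit:

// Presumed shape only; the method body is not shown in this commit.
return new global::LLama.Common.InferenceParams
{
    Temperature = (float)requestSettings.Temperature,
    TopP = (float)requestSettings.TopP,
    PresencePenalty = (float)requestSettings.PresencePenalty,
    FrequencyPenalty = (float)requestSettings.FrequencyPenalty,
    MaxTokens = requestSettings.MaxTokens ?? -1,              // -1 = unbounded in LLamaSharp
    AntiPrompts = requestSettings.StopSequences.ToList()      // requires System.Linq
};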
131 changes: 131 additions & 0 deletions LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
@@ -0,0 +1,131 @@

/* Unmerged change from project 'LLamaSharp.SemanticKernel (netstandard2.0)'
Before:
using Microsoft.SemanticKernel;
After:
using LLamaSharp;
using LLamaSharp.SemanticKernel;
using LLamaSharp.SemanticKernel;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
*/
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel;

public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings
{
/// <summary>
/// Temperature controls the randomness of the completion.
/// The higher the temperature, the more random the completion.
/// </summary>
[JsonPropertyName("temperature")]
public double Temperature { get; set; } = 0;

/// <summary>
/// TopP controls the diversity of the completion.
/// The higher the TopP, the more diverse the completion.
/// </summary>
[JsonPropertyName("top_p")]
public double TopP { get; set; } = 0;

/// <summary>
/// Number between -2.0 and 2.0. Positive values penalize new tokens
/// based on whether they appear in the text so far, increasing the
/// model's likelihood to talk about new topics.
/// </summary>
[JsonPropertyName("presence_penalty")]
public double PresencePenalty { get; set; } = 0;

/// <summary>
/// Number between -2.0 and 2.0. Positive values penalize new tokens
/// based on their existing frequency in the text so far, decreasing
/// the model's likelihood to repeat the same line verbatim.
/// </summary>
[JsonPropertyName("frequency_penalty")]
public double FrequencyPenalty { get; set; } = 0;

/// <summary>
/// Sequences where the completion will stop generating further tokens.
/// </summary>
[JsonPropertyName("stop_sequences")]
public IList<string> StopSequences { get; set; } = Array.Empty<string>();

/// <summary>
/// How many completions to generate for each prompt. Default is 1.
/// Note: Because this parameter generates many completions, it can quickly consume your token quota.
/// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
/// </summary>
[JsonPropertyName("results_per_prompt")]
public int ResultsPerPrompt { get; set; } = 1;

/// <summary>
/// The maximum number of tokens to generate in the completion.
/// </summary>
[JsonPropertyName("max_tokens")]
public int? MaxTokens { get; set; }

/// <summary>
/// Modify the likelihood of specified tokens appearing in the completion.
/// </summary>
[JsonPropertyName("token_selection_biases")]
public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();

/// <summary>
/// Indicates the format of the response which can be used downstream to post-process the messages. Handlebars: handlebars_object. JSON: json_object, etc.
/// </summary>
[JsonPropertyName("response_format")]
public string ResponseFormat { get; set; } = string.Empty;

/// <summary>
/// Create a new settings object with the values from another settings object.
/// </summary>
/// <param name="requestSettings">Template configuration</param>
/// <param name="defaultMaxTokens">Default max tokens</param>
/// <returns>An instance of LLamaSharpPromptExecutionSettings</returns>
public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
{
if (requestSettings is null)
{
return new LLamaSharpPromptExecutionSettings()
{
MaxTokens = defaultMaxTokens
};
}

if (requestSettings is LLamaSharpPromptExecutionSettings requestSettingsChatRequestSettings)
{
return requestSettingsChatRequestSettings;
}

var json = JsonSerializer.Serialize(requestSettings);
var chatRequestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, s_options);

if (chatRequestSettings is not null)
{
return chatRequestSettings;
}

throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings));
}

private static readonly JsonSerializerOptions s_options = CreateOptions();

private static JsonSerializerOptions CreateOptions()
{
JsonSerializerOptions options = new()
{
WriteIndented = true,
MaxDepth = 20,
AllowTrailingCommas = true,
PropertyNameCaseInsensitive = true,
ReadCommentHandling = JsonCommentHandling.Skip,
Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
};

return options;
}
}
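
The three branches of FromRequestSettings can be exercised directly. A small sketch; the ModelId value is made up for illustration:

// 1. null falls back to a fresh instance carrying the supplied default max tokens:
var fromNull = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, defaultMaxTokens: 256);

// 2. an instance of the type itself passes through unchanged:
var passThrough = LLamaSharpPromptExecutionSettings.FromRequestSettings(fromNull);

// 3. any other PromptExecutionSettings is converted via a JSON round-trip:
var converted = LLamaSharpPromptExecutionSettings.FromRequestSettings(
    new PromptExecutionSettings { ModelId = "local-llama" });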
104 changes: 104 additions & 0 deletions LLama.SemanticKernel/LLamaSharpPromptExecutionSettingsConverter.cs
@@ -0,0 +1,104 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel;

/// <summary>
/// JSON converter for <see cref="LLamaSharpPromptExecutionSettings"/>
/// </summary>
public class LLamaSharpPromptExecutionSettingsConverter : JsonConverter<LLamaSharpPromptExecutionSettings>
{
/// <inheritdoc/>
public override LLamaSharpPromptExecutionSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var requestSettings = new LLamaSharpPromptExecutionSettings();

while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
{
if (reader.TokenType == JsonTokenType.PropertyName)
{
string? propertyName = reader.GetString();

if (propertyName is not null)
{
// normalise property name to uppercase
propertyName = propertyName.ToUpperInvariant();
}

reader.Read();

switch (propertyName)
{
case "MODELID":
case "MODEL_ID":
requestSettings.ModelId = reader.GetString();
break;
case "TEMPERATURE":
requestSettings.Temperature = reader.GetDouble();
break;
case "TOPP":
case "TOP_P":
requestSettings.TopP = reader.GetDouble();
break;
case "FREQUENCYPENALTY":
case "FREQUENCY_PENALTY":
requestSettings.FrequencyPenalty = reader.GetDouble();
break;
case "PRESENCEPENALTY":
case "PRESENCE_PENALTY":
requestSettings.PresencePenalty = reader.GetDouble();
break;
case "MAXTOKENS":
case "MAX_TOKENS":
requestSettings.MaxTokens = reader.GetInt32();
break;
case "STOPSEQUENCES":
case "STOP_SEQUENCES":
requestSettings.StopSequences = JsonSerializer.Deserialize<IList<string>>(ref reader, options) ?? Array.Empty<string>();
break;
case "RESULTSPERPROMPT":
case "RESULTS_PER_PROMPT":
requestSettings.ResultsPerPrompt = reader.GetInt32();
break;
case "TOKENSELECTIONBIASES":
case "TOKEN_SELECTION_BIASES":
requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<IDictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
break;
default:
reader.Skip();
break;
}
}
}

return requestSettings;
}

/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, LLamaSharpPromptExecutionSettings value, JsonSerializerOptions options)
{
writer.WriteStartObject();

writer.WriteNumber("temperature", value.Temperature);
writer.WriteNumber("top_p", value.TopP);
writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
writer.WriteNumber("presence_penalty", value.PresencePenalty);
if (value.MaxTokens is null)
{
writer.WriteNull("max_tokens");
}
else
{
writer.WriteNumber("max_tokens", (decimal)value.MaxTokens);
}
writer.WritePropertyName("stop_sequences");
JsonSerializer.Serialize(writer, value.StopSequences, options);
writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);
writer.WritePropertyName("token_selection_biases");
JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);

writer.WriteEndObject();
}
}
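
A quick round-trip through the converter as a sketch: the writer emits snake_case property names, which the reader's case normalisation then accepts on the way back in:

var options = new JsonSerializerOptions
{
    Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
};

var json = JsonSerializer.Serialize(
    new LLamaSharpPromptExecutionSettings { MaxTokens = 100, Temperature = 0.5 }, options);
// json contains "temperature", "top_p", ..., "max_tokens": 100

var roundTripped = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);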
5 changes: 2 additions & 3 deletions LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
@@ -1,5 +1,4 @@
using LLama.Abstractions;
-using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TextGeneration;
@@ -24,7 +23,7 @@ public LLamaSharpTextCompletion(ILLamaExecutor executor)
/// <inheritdoc/>
public async Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
-var settings = ChatRequestSettings.FromRequestSettings(executionSettings);
+var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
var sb = new StringBuilder();
await foreach (var token in result)
@@ -37,7 +36,7 @@ await foreach (var token in result)
/// <inheritdoc/>
public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
-var settings = ChatRequestSettings.FromRequestSettings(executionSettings);
+var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
await foreach (var token in result)
{
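
Wired into a kernel, the service is then reached through the standard SK text-generation path. A sketch assuming the registration pattern from the example file above; the service key is illustrative:

// Assumes 'executor' is an ILLamaExecutor, as in the example file.
var builder = Kernel.CreateBuilder();
builder.Services.AddKeyedSingleton<ITextGenerationService>(
    "local-llama", new LLamaSharpTextCompletion(executor));
var kernel = builder.Build();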
