Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Clean up some things in LiquidPromptTemplate and PromptyKernelExtensions #6118

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@ namespace SemanticKernel.Extensions.PromptTemplates.Liquid.UnitTests;

public class LiquidTemplateFactoryTest
{
[Fact]
public void ItThrowsExceptionForUnknownPromptTemplateFormat()
[Theory]
[InlineData("unknown-format")]
[InlineData(null)]
public void ItThrowsExceptionForUnknownPromptTemplateFormat(string? format)
{
// Arrange
var promptConfig = new PromptTemplateConfig("UnknownFormat")
{
TemplateFormat = "unknown-format",
TemplateFormat = format,
};

var target = new LiquidPromptTemplateFactory();

// Act & Assert
Assert.False(target.TryCreate(promptConfig, out IPromptTemplate? result));
Assert.Null(result);
Assert.Throws<KernelException>(() => target.Create(promptConfig));
}

Expand All @@ -38,7 +42,6 @@ public void ItCreatesLiquidPromptTemplate()
var result = target.Create(promptConfig);

// Assert
Assert.NotNull(result);
Assert.True(result is LiquidPromptTemplate);
Assert.IsType<LiquidPromptTemplate>(result);
}
}
111 changes: 62 additions & 49 deletions dotnet/src/Extensions/PromptTemplates.Liquid/LiquidPromptTemplate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
Expand All @@ -15,53 +16,70 @@ namespace Microsoft.SemanticKernel.PromptTemplates.Liquid;
/// </summary>
internal sealed class LiquidPromptTemplate : IPromptTemplate
{
private readonly PromptTemplateConfig _config;
private static readonly Regex s_roleRegex = new(@"(?<role>system|assistant|user|function):[\s]+");
private static readonly Regex s_roleRegex = new(@"(?<role>system|assistant|user|function):\s+", RegexOptions.Compiled);

/// <summary>
/// Constructor for Liquid PromptTemplate.
/// </summary>
private readonly Template _liquidTemplate;
private readonly Dictionary<string, object> _inputVariables;

/// <summary>Initializes the <see cref="LiquidPromptTemplate"/>.</summary>
/// <param name="config">Prompt template configuration</param>
/// <exception cref="ArgumentException">throw if <see cref="PromptTemplateConfig.TemplateFormat"/> is not <see cref="LiquidPromptTemplateFactory.LiquidTemplateFormat"/></exception>
/// <exception cref="ArgumentException"><see cref="PromptTemplateConfig.TemplateFormat"/> is not <see cref="LiquidPromptTemplateFactory.LiquidTemplateFormat"/>.</exception>
/// <exception cref="ArgumentException">The template in <paramref name="config"/> could not be parsed.</exception>
public LiquidPromptTemplate(PromptTemplateConfig config)
{
if (config.TemplateFormat != LiquidPromptTemplateFactory.LiquidTemplateFormat)
{
throw new ArgumentException($"Invalid template format: {config.TemplateFormat}");
}

this._config = config;
}
// Parse the template now so we can check for errors, understand variable usage, and
// avoid having to parse on each render.
this._liquidTemplate = Template.ParseLiquid(config.Template);
if (this._liquidTemplate.HasErrors)
{
throw new ArgumentException($"The template could not be parsed:{Environment.NewLine}{string.Join(Environment.NewLine, this._liquidTemplate.Messages)}");
}
Debug.Assert(this._liquidTemplate.Page is not null);

/// <inheritdoc/>
public Task<string> RenderAsync(Kernel kernel, KernelArguments? arguments = null, CancellationToken cancellationToken = default)
{
Verify.NotNull(kernel);
// TODO: Update config.InputVariables with any variables referenced by the template but that aren't explicitly defined in the front matter.

var template = this._config.Template;
var liquidTemplate = Template.ParseLiquid(template);
Dictionary<string, object> nonEmptyArguments = new();
foreach (var p in this._config.InputVariables)
// Configure _inputVariables with the default values from the config. This will be used
// in RenderAsync to seed the arguments used when evaluating the template.
this._inputVariables = [];
foreach (var p in config.InputVariables)
{
if (p.Default is null || (p.Default is string s && string.IsNullOrWhiteSpace(s)))
if (p.Default is not null)
{
continue;
this._inputVariables[p.Name] = p.Default;
}

nonEmptyArguments[p.Name] = p.Default;
}
}

/// <inheritdoc/>
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public async Task<string> RenderAsync(Kernel kernel, KernelArguments? arguments = null, CancellationToken cancellationToken = default)
#pragma warning restore CS1998
{
Verify.NotNull(kernel);
cancellationToken.ThrowIfCancellationRequested();

foreach (var p in arguments ?? new KernelArguments())
Dictionary<string, object>? nonEmptyArguments = null;
if (this._inputVariables.Count is > 0 || arguments?.Count is > 0)
{
if (p.Value is null)
nonEmptyArguments = new(this._inputVariables);
if (arguments is not null)
{
continue;
foreach (var p in arguments)
{
if (p.Value is not null)
{
nonEmptyArguments[p.Key] = p.Value;
}
}
}

nonEmptyArguments[p.Key] = p.Value;
}

var renderedResult = liquidTemplate.Render(nonEmptyArguments);
var renderedResult = this._liquidTemplate.Render(nonEmptyArguments);

// parse chat history
// for every text like below
Expand All @@ -72,35 +90,30 @@ public Task<string> RenderAsync(Kernel kernel, KernelArguments? arguments = null
// <message role="system|assistant|user|function">
// xxxx
// </message>

var splits = s_roleRegex.Split(renderedResult);

// if no role is found, return the entire text
if (splits.Length == 1)
if (splits.Length > 1)
{
return Task.FromResult(renderedResult);
}
// otherwise, the split text chunks will be in the following format
// [0] = ""
// [1] = role information
// [2] = message content
// [3] = role information
// [4] = message content
// ...
// we will iterate through the array and create a new string with the following format
var sb = new StringBuilder();
for (var i = 1; i < splits.Length; i += 2)
{
sb.Append("<message role=\"").Append(splits[i]).AppendLine("\">");
sb.AppendLine(splits[i + 1]);
sb.AppendLine("</message>");
}

// otherwise, the split text chunks will be in the following format
// [0] = ""
// [1] = role information
// [2] = message content
// [3] = role information
// [4] = message content
// ...
// we will iterate through the array and create a new string with the following format
var sb = new StringBuilder();
for (var i = 1; i < splits.Length; i += 2)
{
var role = splits[i];
var content = splits[i + 1];
sb.Append("<message role=\"").Append(role).AppendLine("\">");
sb.AppendLine(content);
sb.AppendLine("</message>");
renderedResult = sb.ToString();
}

renderedResult = sb.ToString();

return Task.FromResult(renderedResult);
return renderedResult;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ public sealed class LiquidPromptTemplateFactory : IPromptTemplateFactory
/// <inheritdoc/>
public bool TryCreate(PromptTemplateConfig templateConfig, [NotNullWhen(true)] out IPromptTemplate? result)
{
if (templateConfig.TemplateFormat.Equals(LiquidTemplateFormat, StringComparison.Ordinal))
Verify.NotNull(templateConfig);

if (LiquidTemplateFormat.Equals(templateConfig.TemplateFormat, StringComparison.Ordinal))
{
result = new LiquidPromptTemplate(templateConfig);
return true;
Expand Down
107 changes: 93 additions & 14 deletions dotnet/src/Functions/Functions.Prompty.UnitTests/PromptyTest.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.TextGeneration;
using Xunit;

namespace SemanticKernel.Functions.Prompty.UnitTests;

public sealed class PromptyTest
{
[Fact]
public void ChatPromptyTest()
{
// Arrange
var kernel = Kernel.CreateBuilder()
.Build();

var cwd = Directory.GetCurrentDirectory();
var chatPromptyPath = Path.Combine(cwd, "TestData", "chat.prompty");
Kernel kernel = new();
var chatPromptyPath = Path.Combine("TestData", "chat.prompty");
var promptyTemplate = File.ReadAllText(chatPromptyPath);

// Act
Expand All @@ -34,11 +39,8 @@ public void ChatPromptyTest()
public void ChatPromptyShouldSupportCreatingOpenAIExecutionSettings()
{
// Arrange
var kernel = Kernel.CreateBuilder()
.Build();

var cwd = Directory.GetCurrentDirectory();
var chatPromptyPath = Path.Combine(cwd, "TestData", "chat.prompty");
Kernel kernel = new();
var chatPromptyPath = Path.Combine("TestData", "chat.prompty");

// Act
var kernelFunction = kernel.CreateFunctionFromPromptyFile(chatPromptyPath);
Expand Down Expand Up @@ -70,10 +72,8 @@ public void ChatPromptyShouldSupportCreatingOpenAIExecutionSettings()
public void ItShouldCreateFunctionFromPromptYamlWithNoExecutionSettings()
{
// Arrange
var kernel = Kernel.CreateBuilder()
.Build();
var cwd = Directory.GetCurrentDirectory();
var promptyPath = Path.Combine(cwd, "TestData", "chatNoExecutionSettings.prompty");
Kernel kernel = new();
var promptyPath = Path.Combine("TestData", "chatNoExecutionSettings.prompty");

// Act
var kernelFunction = kernel.CreateFunctionFromPromptyFile(promptyPath);
Expand All @@ -83,6 +83,85 @@ public void ItShouldCreateFunctionFromPromptYamlWithNoExecutionSettings()
Assert.Equal("prompty_with_no_execution_setting", kernelFunction.Name);
Assert.Equal("prompty without execution setting", kernelFunction.Description);
Assert.Single(kernelFunction.Metadata.Parameters);
Assert.Equal("prompt", kernelFunction.Metadata.Parameters[0].Name);
Assert.Empty(kernelFunction.ExecutionSettings!);
}

[Theory]
[InlineData("""
---
name: SomePrompt
---
Abc
""")]
[InlineData("""
---
name: SomePrompt
---
Abc
""")]
[InlineData("""
---a
name: SomePrompt
---
Abc
""")]
[InlineData("""
---
name: SomePrompt
---b
Abc
""")]
public void ItRequiresStringSeparatorPlacement(string prompt)
{
// Arrange
Kernel kernel = new();

// Act / Assert
Assert.Throws<ArgumentException>(() => kernel.CreateFunctionFromPrompty(prompt));
}

[Fact]
public async Task ItSupportsSeparatorInContentAsync()
{
// Arrange
IKernelBuilder builder = Kernel.CreateBuilder();
builder.Services.AddSingleton<ITextGenerationService>(_ => new EchoTextGenerationService());
Kernel kernel = builder.Build();

// Act
var kernelFunction = kernel.CreateFunctionFromPrompty("""
---
name: SomePrompt
description: This is the description.
---
Abc---def
---
Efg
""");

// Assert
Assert.NotNull(kernelFunction);
Assert.Equal("SomePrompt", kernelFunction.Name);
Assert.Equal("This is the description.", kernelFunction.Description);
Assert.Equal("""
Abc---def
---
Efg
""", await kernelFunction.InvokeAsync<string>(kernel));
}

private sealed class EchoTextGenerationService : ITextGenerationService
{
public IReadOnlyDictionary<string, object?> Attributes { get; } = new Dictionary<string, object?>();

public Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) =>
Task.FromResult<IReadOnlyList<TextContent>>([new TextContent(prompt)]);

public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await Task.Delay(0, cancellationToken);
yield return new StreamingTextContent(prompt);
}
}
}
Loading
Loading