Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batch upsert support for SQL Server Memory DB #489

Merged
merged 7 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion extensions/SQLServer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ To run Kernel Memory service with SQL Server backend:
.Build();
```


## Local tests with Docker

You can test the connector locally with Docker:
Expand All @@ -84,3 +83,25 @@ For more information about the SQL Server Linux container:

- https://learn.microsoft.com/sql/linux/quickstart-install-connect-docker
- https://devblogs.microsoft.com/azure-sql/development-with-sql-in-containers-on-macos/

## Batch Upsert Feature

The SQL Server Memory DB now supports batch upsert operations, which improve performance when writing many records. Multiple records can be inserted or updated in a single call that reuses one database connection for the whole batch.

### Using Batch Upsert

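To use the batch upsert feature, call the `BatchUpsertAsync` method. It accepts an index name and an enumerable of `MemoryRecord` objects, upserts each record, and returns an asynchronous stream (`IAsyncEnumerable<string>`) of the upserted record IDs. Records are written as the stream is enumerated, so the result must be consumed, for example with `await foreach`.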

Example:

```csharp
var records = new List<MemoryRecord>
{
    // In real usage, also set Vector (the embedding) and Tags on each record.
    new MemoryRecord { Id = "id1", Payload = new Dictionary<string, object> { { "key", "value1" } } },
    new MemoryRecord { Id = "id2", Payload = new Dictionary<string, object> { { "key", "value2" } } }
};

// BatchUpsertAsync returns an IAsyncEnumerable<string>: records are written as the
// stream is enumerated, and each yielded value is the ID of a record just upserted.
await foreach (var id in memory.BatchUpsertAsync("yourIndexName", records))
{
    Console.WriteLine($"Upserted record {id}");
}
```

This method efficiently handles the insertion or updating of records, significantly improving performance for operations involving large datasets.
158 changes: 85 additions & 73 deletions extensions/SQLServer/SQLServer/SqlServerMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ namespace Microsoft.KernelMemory.MemoryDb.SQLServer;
/// <summary>
/// Represents a memory store implementation that uses a SQL Server database as its backing store.
/// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities",
Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")]
public class SqlServerMemory : IMemoryDb
#pragma warning disable CA2100 // SQL reviewed for user input validation
public class SqlServerMemory : IMemoryDb, IMemoryDbBatchUpsert
dluc marked this conversation as resolved.
Show resolved Hide resolved
{
/// <summary>
/// The SQL Server configuration.
Expand Down Expand Up @@ -366,91 +365,104 @@ INNER JOIN

/// <inheritdoc/>
public async Task<string> UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default)
{
    // A single upsert is a batch of one: delegate to the batch path and
    // hand back the first ID it yields.
    var singleRecordBatch = new[] { record };
    string upsertedId = null!;

    await foreach (var id in this.BatchUpsertAsync(index, singleRecordBatch, cancellationToken).ConfigureAwait(false))
    {
        upsertedId = id;
        break;
    }

    // Remains null only when the batch stream yielded nothing.
    return upsertedId;
}

/// <inheritdoc/>
public async IAsyncEnumerable<string> BatchUpsertAsync(string index, IEnumerable<MemoryRecord> records, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
index = NormalizeIndexName(index);

if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false)))
{
// Index does not exist
return string.Empty;
throw new IndexNotFoundException($"The index '{index}' does not exist.");
}

using var connection = new SqlConnection(this._config.ConnectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

using (SqlCommand command = connection.CreateCommand())
foreach (var record in records)
{
command.CommandText = $@"
BEGIN TRANSACTION;

MERGE INTO {this.GetFullTableName(this._config.MemoryTableName)}
USING (SELECT @key) as [src]([key])
ON {this.GetFullTableName(this._config.MemoryTableName)}.[key] = [src].[key]
WHEN MATCHED THEN
UPDATE SET payload=@payload, embedding=@embedding, tags=@tags
WHEN NOT MATCHED THEN
INSERT ([id], [key], [collection], [payload], [tags], [embedding])
VALUES (NEWID(), @key, @index, @payload, @tags, @embedding);

MERGE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} AS [tgt]
USING (
SELECT
{this.GetFullTableName(this._config.MemoryTableName)}.[id],
cast([vector].[key] AS INT) AS [vector_value_id],
cast([vector].[value] AS FLOAT) AS [vector_value]
FROM {this.GetFullTableName(this._config.MemoryTableName)}
CROSS APPLY
openjson(@embedding) [vector]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
) AS [src]
ON [tgt].[memory_id] = [src].[id] AND [tgt].[vector_value_id] = [src].[vector_value_id]
WHEN MATCHED THEN
UPDATE SET [tgt].[vector_value] = [src].[vector_value]
WHEN NOT MATCHED THEN
INSERT ([memory_id], [vector_value_id], [vector_value])
VALUES ([src].[id],
[src].[vector_value_id],
[src].[vector_value] );

DELETE FROM [tgt]
FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index;

MERGE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
USING (
SELECT
{this.GetFullTableName(this._config.MemoryTableName)}.[id],
cast([tags].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name],
[tag_value].[value] AS [value]
FROM {this.GetFullTableName(this._config.MemoryTableName)}
CROSS APPLY openjson(@tags) [tags]
CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value]
using (SqlCommand command = connection.CreateCommand())
{
command.CommandText = $@"
BEGIN TRANSACTION;

MERGE INTO {this.GetFullTableName(this._config.MemoryTableName)}
USING (SELECT @key) as [src]([key])
ON {this.GetFullTableName(this._config.MemoryTableName)}.[key] = [src].[key]
WHEN MATCHED THEN
UPDATE SET payload=@payload, embedding=@embedding, tags=@tags
WHEN NOT MATCHED THEN
INSERT ([id], [key], [collection], [payload], [tags], [embedding])
VALUES (NEWID(), @key, @index, @payload, @tags, @embedding);

MERGE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} AS [tgt]
USING (
SELECT
{this.GetFullTableName(this._config.MemoryTableName)}.[id],
cast([vector].[key] AS INT) AS [vector_value_id],
cast([vector].[value] AS FLOAT) AS [vector_value]
FROM {this.GetFullTableName(this._config.MemoryTableName)}
CROSS APPLY
openjson(@embedding) [vector]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
) AS [src]
ON [tgt].[memory_id] = [src].[id] AND [tgt].[vector_value_id] = [src].[vector_value_id]
WHEN MATCHED THEN
UPDATE SET [tgt].[vector_value] = [src].[vector_value]
WHEN NOT MATCHED THEN
INSERT ([memory_id], [vector_value_id], [vector_value])
VALUES ([src].[id],
[src].[vector_value_id],
[src].[vector_value] );

DELETE FROM [tgt]
FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
) AS [src]
ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name]
WHEN MATCHED THEN
UPDATE SET [tgt].[value] = [src].[value]
WHEN NOT MATCHED THEN
INSERT ([memory_id], [name], [value])
VALUES ([src].[id],
[src].[tag_name],
[src].[value]);

COMMIT;";
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index;

MERGE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
USING (
SELECT
{this.GetFullTableName(this._config.MemoryTableName)}.[id],
cast([tags].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name],
[tag_value].[value] AS [value]
FROM {this.GetFullTableName(this._config.MemoryTableName)}
CROSS APPLY openjson(@tags) [tags]
CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
) AS [src]
ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name]
WHEN MATCHED THEN
UPDATE SET [tgt].[value] = [src].[value]
WHEN NOT MATCHED THEN
INSERT ([memory_id], [name], [value])
VALUES ([src].[id],
[src].[tag_name],
[src].[value]);

command.Parameters.AddWithValue("@index", index);
command.Parameters.AddWithValue("@key", record.Id);
command.Parameters.AddWithValue("@payload", JsonSerializer.Serialize(record.Payload) ?? (object)DBNull.Value);
command.Parameters.AddWithValue("@tags", JsonSerializer.Serialize(record.Tags) ?? (object)DBNull.Value);
command.Parameters.AddWithValue("@embedding", JsonSerializer.Serialize(record.Vector.Data.ToArray()));
COMMIT;";

await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
command.Parameters.AddWithValue("@index", index);
command.Parameters.AddWithValue("@key", record.Id);
command.Parameters.AddWithValue("@payload", JsonSerializer.Serialize(record.Payload) ?? (object)DBNull.Value);
command.Parameters.AddWithValue("@tags", JsonSerializer.Serialize(record.Tags) ?? (object)DBNull.Value);
command.Parameters.AddWithValue("@embedding", JsonSerializer.Serialize(record.Vector.Data.ToArray()));

return record.Id;
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);

yield return record.Id;
}
}
}

Expand Down
Loading