Skip to content

Commit

Permalink
feat: BaseRetriever, VectorStore, VectorStoreRetriever and Chroma (#42)
Browse files Browse the repository at this point in the history
* vectorstore and chroma
* tests
* fix rebase artefacts
* tests
* supress local tests and primitive serialize/deserialize
* converter see cref

---------

Co-authored-by: Evgenii Khoroshev <[email protected]>
Co-authored-by: Konstantin S <[email protected]>
  • Loading branch information
3 people authored Nov 3, 2023
1 parent d7d1544 commit 4ab15b9
Show file tree
Hide file tree
Showing 32 changed files with 1,350 additions and 28 deletions.
7 changes: 7 additions & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Sources.WebBase.I
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Core.UnitTests", "src\tests\LangChain.Core.UnitTests\LangChain.Core.UnitTests.csproj", "{91CCC7E4-70E2-4589-8F7A-9B5BA2844DD1}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Databases.Chroma.IntegrationTests", "src\tests\LangChain.Databases.Chroma.IntegrationTests\LangChain.Databases.Chroma.IntegrationTests.csproj", "{302CD326-ADC3-484E-8F41-A54934A01D70}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -298,6 +300,10 @@ Global
{91CCC7E4-70E2-4589-8F7A-9B5BA2844DD1}.Debug|Any CPU.Build.0 = Debug|Any CPU
{91CCC7E4-70E2-4589-8F7A-9B5BA2844DD1}.Release|Any CPU.ActiveCfg = Release|Any CPU
{91CCC7E4-70E2-4589-8F7A-9B5BA2844DD1}.Release|Any CPU.Build.0 = Release|Any CPU
{302CD326-ADC3-484E-8F41-A54934A01D70}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{302CD326-ADC3-484E-8F41-A54934A01D70}.Debug|Any CPU.Build.0 = Debug|Any CPU
{302CD326-ADC3-484E-8F41-A54934A01D70}.Release|Any CPU.ActiveCfg = Release|Any CPU
{302CD326-ADC3-484E-8F41-A54934A01D70}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -348,5 +354,6 @@ Global
{01DC2D34-958F-4381-99AD-E91E3CEE31FD} = {7F35205F-1692-4702-AA88-3C29BBB121BC}
{454BA81E-861D-4908-B4D3-D1F2CDEF2C81} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{91CCC7E4-70E2-4589-8F7A-9B5BA2844DD1} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{302CD326-ADC3-484E-8F41-A54934A01D70} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
EndGlobalSection
EndGlobal
26 changes: 26 additions & 0 deletions src/libs/Databases/LangChain.Databases.Chroma/AsyncEnumerable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
namespace LangChain.Databases;

/// <summary>
///
/// </summary>
public static class AsyncEnumerable
{
/// <summary>
///
/// </summary>
/// <param name="source"></param>
/// <param name="cancellationToken"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public static async ValueTask<List<T>> ToListAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
{
var result = new List<T>();

await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
{
result.Add(item);
}

return result;
}
}
336 changes: 336 additions & 0 deletions src/libs/Databases/LangChain.Databases.Chroma/ChromaVectorStore.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,336 @@
using System.Text.Json;
using LangChain.Abstractions.Embeddings.Base;
using LangChain.Docstore;
using LangChain.VectorStores;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.Memory.Chroma;
using Microsoft.SemanticKernel.Connectors.Memory.Chroma.Http.ApiSchema;
using Microsoft.SemanticKernel.Memory;

namespace LangChain.Databases;

/// <summary>
/// ChromaDB vector store.
/// <see cref="https://api.python.langchain.com/en/latest/_modules/langchain/vectorstores/chroma.html"/>
/// </summary>
public class ChromaVectorStore : VectorStore
{
private const string LangchainDefaultCollectionName = "langchain";

// TODO: SemanticKernel impl doesn't support collection metadata. Need changes when moved to another impl
private Dictionary<string, string>? CollectionMetadata { get; } = new();

private readonly ChromaMemoryStore _store;

private readonly ChromaClient _client;
private readonly string _collectionName;
private readonly JsonSerializerOptions _jsonSerializerOptions;

/// <inheritdoc />
public ChromaVectorStore(
HttpClient httpClient,
string endpoint,
IEmbeddings embeddings,
string collectionName = LangchainDefaultCollectionName)
{
_client = new ChromaClient(httpClient, endpoint);
Embeddings = embeddings;
_collectionName = collectionName;

_store = new ChromaMemoryStore(_client);

_client.CreateCollectionAsync(_collectionName).GetAwaiter().GetResult();

_jsonSerializerOptions = new JsonSerializerOptions
{
Converters =
{
new ObjectAsPrimitiveConverter(
floatFormat: FloatFormat.Double,
unknownNumberFormat: UnknownNumberFormat.Error,
objectFormat: ObjectFormat.Expando)
},
WriteIndented = true,
};
}

/// <summary>
/// Get collection
/// </summary>
public async Task<ChromaCollectionModel?> GetCollectionAsync()
{
return await _client.GetCollectionAsync(_collectionName).ConfigureAwait(false);
}

/// <summary>
/// Delete collection
/// </summary>
public async Task DeleteCollectionAsync()
{
await _store.DeleteCollectionAsync(_collectionName).ConfigureAwait(false);
}

/// <summary>
/// Get collection
/// </summary>
public async Task<Document?> GetDocumentByIdAsync(string id)
{
var record = await _store.GetAsync(_collectionName, id, withEmbedding: true).ConfigureAwait(false);

if (record == null)
{
return null;
}

var text = record.Metadata.Text;
var metadata = DeserializeMetadata(record.Metadata);

return new Document(text, metadata);
}

/// <inheritdoc />
public override async Task<IEnumerable<string>> AddDocumentsAsync(
IEnumerable<Document> documents,
CancellationToken cancellationToken = default)
{
var documentsArray = documents.ToArray();
var texts = new string[documentsArray.Length];
var ids = new string[documentsArray.Length];
var metadatas = new Dictionary<string, object>[documentsArray.Length];
for (var index = 0; index < documentsArray.Length; index++)
{
ids[index] = Guid.NewGuid().ToString();
texts[index] = documentsArray[index].PageContent;
metadatas[index] = documentsArray[index].Metadata;
}

var result = await AddCoreAsync(texts, metadatas, ids, cancellationToken).ConfigureAwait(false);

return result;
}

/// <inheritdoc />
public override async Task<IEnumerable<string>> AddTextsAsync(
IEnumerable<string> texts,
IEnumerable<Dictionary<string, object>>? metadatas = null,
CancellationToken cancellationToken = default)
{
var textsArray = texts.ToArray();
var metadatasArray = metadatas?.ToArray() ?? new Dictionary<string, object>?[textsArray.Length];
var ids = new string[textsArray.Length];

for (var index = 0; index < textsArray.Length; index++)
{
ids[index] = Guid.NewGuid().ToString();
metadatasArray[index] ??= new Dictionary<string, object>();
}

var result = await AddCoreAsync(
textsArray,
metadatasArray as Dictionary<string, object>[],
ids,
cancellationToken)
.ConfigureAwait(false);

return result;
}

/// <inheritdoc />
public override async Task<bool> DeleteAsync(
IEnumerable<string> ids,
CancellationToken cancellationToken = default)
{
await _store.RemoveBatchAsync(_collectionName, ids, cancellationToken).ConfigureAwait(false);

return true;
}

/// <inheritdoc />
public override async Task<IEnumerable<Document>> SimilaritySearchAsync(
string query,
int k = 4,
CancellationToken cancellationToken = default)
{
var embeddings = await Embeddings
.EmbedQueryAsync(query, cancellationToken)
.ConfigureAwait(false);

var documentsWithScores = await SimilaritySearchWithVectorCoreAsync(embeddings, k, cancellationToken).ConfigureAwait(false);
var documents = documentsWithScores.Select(dws => dws.Item1);

return documents;
}

/// <inheritdoc />
public override async Task<IEnumerable<Document>> SimilaritySearchByVectorAsync(
IEnumerable<float> embedding,
int k = 4,
CancellationToken cancellationToken = default)
{
var documentsWithScores = await SimilaritySearchWithVectorCoreAsync(embedding.ToArray(), k, cancellationToken).ConfigureAwait(false);
var documents = documentsWithScores.Select(dws => dws.Item1);

return documents;
}

/// <inheritdoc />
public override async Task<IEnumerable<(Document, float)>> SimilaritySearchWithScoreAsync(
string query,
int k = 4,
CancellationToken cancellationToken = default)
{
var embeddings = await Embeddings
.EmbedQueryAsync(query, cancellationToken)
.ConfigureAwait(false);

var documentsWithScores = await SimilaritySearchWithVectorCoreAsync(embeddings, k, cancellationToken).ConfigureAwait(false);

return documentsWithScores;
}

/// <inheritdoc />
public override async Task<IEnumerable<Document>> MaxMarginalRelevanceSearchByVector(
IEnumerable<float> embedding,
int k = 4,
int fetchK = 20,
float lambdaMult = 0.5f,
CancellationToken cancellationToken = default)
=> throw new NotSupportedException("Querying not supported by SemanticKernel impl.");

/// <inheritdoc />
public override async Task<IEnumerable<Document>> MaxMarginalRelevanceSearch(
string query,
int k = 4,
int fetchK = 20,
float lambdaMult = 0.5f,
CancellationToken cancellationToken = default)
{
var embeddings = await Embeddings
.EmbedQueryAsync(query, cancellationToken)
.ConfigureAwait(false);

var documents = await MaxMarginalRelevanceSearchByVector(
embeddings,
k,
fetchK,
lambdaMult,
cancellationToken)
.ConfigureAwait(false);

return documents;
}

/// <inheritdoc />
protected override Func<float, float> SelectRelevanceScoreFn()
{
if (OverrideRelevanceScoreFn != null)
{
return OverrideRelevanceScoreFn;
}

return GetFuncForDistance();

Func<float, float> GetFuncForDistance()
{
return distance =>
{
const string distanceKey = "hnsw:space";
var distanceType = "l2";
if (CollectionMetadata != null
&& CollectionMetadata.TryGetValue(distanceKey, out var value))
{
distanceType = value;
}
return distanceType switch
{
"cosine" => CosineRelevanceScoreFn(distance),
"l2" => EuclideanRelevanceScoreFn(distance),
"ip" => MaxInnerProductRelevanceScoreFn(distance),
_ => throw new ArgumentException(
$@"No supported normalization function for distance metric of type: {distanceType}.
Consider providing relevance_score_fn to Chroma constructor.")
};
};
}
}

private async Task<IEnumerable<string>> AddCoreAsync(
string[] texts,
Dictionary<string, object>[] metadatas,
string[] ids,
CancellationToken cancellationToken)
{
var embeddings = await Embeddings
.EmbedDocumentsAsync(texts, cancellationToken)
.ConfigureAwait(false);

var records = new MemoryRecord[texts.Length];
for (var index = 0; index < texts.Length; index++)
{
// TODO: check: description, externalSourceName, key
records[index] = new MemoryRecord
(
new MemoryRecordMetadata
(
isReference: false,
id: ids[index],
text: texts[index],
description: string.Empty,
externalSourceName: string.Empty,
additionalMetadata: SerializeMetadata(metadatas[index])
),
new Embedding<float>(embeddings[index]),
key: null
);
}

var resultIds = new List<string>(texts.Length);
var resultIdsIterator = _store.UpsertBatchAsync(_collectionName, records, cancellationToken);
await foreach (var item in resultIdsIterator.ConfigureAwait(false))
{
resultIds.Add(item);
}

return resultIds;
}

private async Task<IEnumerable<(Document, float)>> SimilaritySearchWithVectorCoreAsync(
float[] embeddings,
int k,
CancellationToken cancellationToken)
{
var matches = await _store
.GetNearestMatchesAsync(
_collectionName,
new Embedding<float>(embeddings),
k,
cancellationToken: cancellationToken)
.ToListAsync(cancellationToken)
.ConfigureAwait(false);

return matches.Select(
record =>
{
var text = record.Item1.Metadata.Text;
var metadata = DeserializeMetadata(record.Item1.Metadata);
return (new Document(text, metadata), (float)record.Item2);
});
}

private string SerializeMetadata(Dictionary<string, object> metadata)
{
return JsonSerializer.Serialize(metadata, _jsonSerializerOptions);
}

private Dictionary<string, object> DeserializeMetadata(MemoryRecordMetadata metadata)
{
// TODO: issue with this method is it returns values as JsonElements instead of primitive types
return JsonSerializer.Deserialize<Dictionary<string, object>>(metadata.AdditionalMetadata, _jsonSerializerOptions)
?? new Dictionary<string, object>();
}
}
Loading

0 comments on commit 4ab15b9

Please sign in to comment.