-
-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: BaseRetriever, VectorStore, VectorStoreRetriever and Chroma (#42)
* vectorstore and chroma * tests * fix rebase artefacts * tests * supress local tests and primitive serialize/deserialize * converter see cref --------- Co-authored-by: Evgenii Khoroshev <[email protected]> Co-authored-by: Konstantin S <[email protected]>
- Loading branch information
1 parent
d7d1544
commit 4ab15b9
Showing
32 changed files
with
1,350 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
src/libs/Databases/LangChain.Databases.Chroma/AsyncEnumerable.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
namespace LangChain.Databases; | ||
|
||
/// <summary>
/// Helper extensions for consuming <see cref="IAsyncEnumerable{T}"/> sequences.
/// </summary>
public static class AsyncEnumerable
{
    /// <summary>
    /// Asynchronously enumerates <paramref name="source"/> to completion and collects every item into a list.
    /// </summary>
    /// <param name="source">The asynchronous sequence to drain.</param>
    /// <param name="cancellationToken">Token attached to the enumeration via <c>WithCancellation</c>.</param>
    /// <typeparam name="T">Element type of the sequence.</typeparam>
    /// <returns>A new list containing the items in the order they were produced.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="source"/> is <c>null</c>; the exception is surfaced when the returned task is awaited.
    /// </exception>
    public static async ValueTask<List<T>> ToListAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
    {
        // Fail with a descriptive exception rather than a NullReferenceException raised while enumerating.
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }

        var result = new List<T>();

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            result.Add(item);
        }

        return result;
    }
}
336 changes: 336 additions & 0 deletions
336
src/libs/Databases/LangChain.Databases.Chroma/ChromaVectorStore.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,336 @@ | ||
using System.Text.Json; | ||
using LangChain.Abstractions.Embeddings.Base; | ||
using LangChain.Docstore; | ||
using LangChain.VectorStores; | ||
using Microsoft.SemanticKernel.AI.Embeddings; | ||
using Microsoft.SemanticKernel.Connectors.Memory.Chroma; | ||
using Microsoft.SemanticKernel.Connectors.Memory.Chroma.Http.ApiSchema; | ||
using Microsoft.SemanticKernel.Memory; | ||
|
||
namespace LangChain.Databases; | ||
|
||
/// <summary> | ||
/// ChromaDB vector store. | ||
/// <see href="https://api.python.langchain.com/en/latest/_modules/langchain/vectorstores/chroma.html"/> | ||
/// </summary> | ||
public class ChromaVectorStore : VectorStore | ||
{ | ||
private const string LangchainDefaultCollectionName = "langchain"; | ||
|
||
// TODO: SemanticKernel impl doesn't support collection metadata. Need changes when moved to another impl | ||
private Dictionary<string, string>? CollectionMetadata { get; } = new(); | ||
|
||
private readonly ChromaMemoryStore _store; | ||
|
||
private readonly ChromaClient _client; | ||
private readonly string _collectionName; | ||
private readonly JsonSerializerOptions _jsonSerializerOptions; | ||
|
||
/// <summary>
/// Creates a vector store bound to a single ChromaDB collection.
/// </summary>
/// <param name="httpClient">HTTP client handed to the underlying <see cref="ChromaClient"/>.</param>
/// <param name="endpoint">Endpoint handed to the underlying <see cref="ChromaClient"/>.</param>
/// <param name="embeddings">Embedding provider used to vectorize documents and queries.</param>
/// <param name="collectionName">Name of the collection to operate on; defaults to <c>"langchain"</c>.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="httpClient"/> or <paramref name="embeddings"/> is <c>null</c>.
/// </exception>
/// <exception cref="ArgumentException"><paramref name="collectionName"/> is null, empty or whitespace.</exception>
/// <remarks>
/// The constructor blocks the calling thread until the create-collection request has completed
/// (<c>GetAwaiter().GetResult()</c>), so it performs network I/O synchronously.
/// </remarks>
public ChromaVectorStore(
    HttpClient httpClient,
    string endpoint,
    IEmbeddings embeddings,
    string collectionName = LangchainDefaultCollectionName)
{
    // Validate up front so a bad argument fails here instead of deep inside a later request.
    if (httpClient == null)
    {
        throw new ArgumentNullException(nameof(httpClient));
    }

    if (embeddings == null)
    {
        throw new ArgumentNullException(nameof(embeddings));
    }

    if (string.IsNullOrWhiteSpace(collectionName))
    {
        throw new ArgumentException("Collection name must not be null or empty.", nameof(collectionName));
    }

    _client = new ChromaClient(httpClient, endpoint);
    Embeddings = embeddings;
    _collectionName = collectionName;

    _store = new ChromaMemoryStore(_client);

    // Options used by SerializeMetadata/DeserializeMetadata for the record's additional metadata.
    _jsonSerializerOptions = new JsonSerializerOptions
    {
        Converters =
        {
            new ObjectAsPrimitiveConverter(
                floatFormat: FloatFormat.Double,
                unknownNumberFormat: UnknownNumberFormat.Error,
                objectFormat: ObjectFormat.Expando)
        },
        WriteIndented = true,
    };

    // NOTE(review): sync-over-async; a constructor cannot await, so this blocks until the request finishes.
    _client.CreateCollectionAsync(_collectionName).GetAwaiter().GetResult();
}
|
||
/// <summary>
/// Fetches the Chroma collection this store is bound to.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the Chroma client request.</param>
/// <returns>The collection model returned by the client, or <c>null</c> when the client returns none.</returns>
public async Task<ChromaCollectionModel?> GetCollectionAsync(CancellationToken cancellationToken = default)
{
    return await _client.GetCollectionAsync(_collectionName, cancellationToken).ConfigureAwait(false);
}
|
||
/// <summary>
/// Deletes the Chroma collection this store is bound to.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the memory store request.</param>
public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default)
{
    await _store.DeleteCollectionAsync(_collectionName, cancellationToken).ConfigureAwait(false);
}
|
||
/// <summary>
/// Gets a single document from the collection by its record id.
/// </summary>
/// <param name="id">Id of the record to load.</param>
/// <param name="cancellationToken">Token forwarded to the memory store request.</param>
/// <returns>
/// A <see cref="Document"/> built from the record's text and deserialized additional metadata,
/// or <c>null</c> when the store returns no record for <paramref name="id"/>.
/// </returns>
public async Task<Document?> GetDocumentByIdAsync(string id, CancellationToken cancellationToken = default)
{
    // NOTE(review): the embedding is requested but not used below — confirm whether withEmbedding can be false.
    var record = await _store
        .GetAsync(_collectionName, id, withEmbedding: true, cancellationToken: cancellationToken)
        .ConfigureAwait(false);

    if (record is null)
    {
        return null;
    }

    var text = record.Metadata.Text;
    var metadata = DeserializeMetadata(record.Metadata);

    return new Document(text, metadata);
}
|
||
/// <inheritdoc />
/// <remarks>
/// Every document receives a freshly generated GUID as its id; its page content and metadata
/// are forwarded to the shared add routine.
/// </remarks>
public override async Task<IEnumerable<string>> AddDocumentsAsync(
    IEnumerable<Document> documents,
    CancellationToken cancellationToken = default)
{
    var docs = documents.ToArray();

    // Project the documents into the three parallel arrays expected by AddCoreAsync.
    var newIds = Array.ConvertAll(docs, _ => Guid.NewGuid().ToString());
    var contents = Array.ConvertAll(docs, doc => doc.PageContent);
    var metadataPerDocument = Array.ConvertAll<Document, Dictionary<string, object>>(docs, doc => doc.Metadata);

    return await AddCoreAsync(contents, metadataPerDocument, newIds, cancellationToken).ConfigureAwait(false);
}
|
||
/// <inheritdoc />
/// <remarks>
/// Every text receives a freshly generated GUID as its id. When <paramref name="metadatas"/> is
/// <c>null</c>, shorter than <paramref name="texts"/>, or contains <c>null</c> entries, the missing
/// entries are replaced with empty dictionaries. Entries beyond the number of texts are ignored.
/// </remarks>
public override async Task<IEnumerable<string>> AddTextsAsync(
    IEnumerable<string> texts,
    IEnumerable<Dictionary<string, object>>? metadatas = null,
    CancellationToken cancellationToken = default)
{
    var textsArray = texts.ToArray();
    var suppliedMetadatas = metadatas?.ToArray() ?? Array.Empty<Dictionary<string, object>>();

    var ids = new string[textsArray.Length];
    var metadatasArray = new Dictionary<string, object>[textsArray.Length];

    for (var index = 0; index < textsArray.Length; index++)
    {
        ids[index] = Guid.NewGuid().ToString();

        // Guard the index: the caller may supply fewer metadata entries than texts.
        var supplied = index < suppliedMetadatas.Length ? suppliedMetadatas[index] : null;
        metadatasArray[index] = supplied ?? new Dictionary<string, object>();
    }

    return await AddCoreAsync(textsArray, metadatasArray, ids, cancellationToken).ConfigureAwait(false);
}
|
||
/// <inheritdoc />
/// <remarks>
/// Removes the records with the given ids from the bound collection and returns <c>true</c>
/// once the awaited removal call has completed.
/// </remarks>
public override async Task<bool> DeleteAsync(
    IEnumerable<string> ids,
    CancellationToken cancellationToken = default)
{
    var removal = _store.RemoveBatchAsync(_collectionName, ids, cancellationToken);
    await removal.ConfigureAwait(false);

    // The store call has no result of its own, so completion is reported as success.
    return true;
}
|
||
/// <inheritdoc />
/// <remarks>
/// Embeds <paramref name="query"/>, runs the vector search and returns only the documents,
/// dropping the scores.
/// </remarks>
public override async Task<IEnumerable<Document>> SimilaritySearchAsync(
    string query,
    int k = 4,
    CancellationToken cancellationToken = default)
{
    var queryVector = await Embeddings.EmbedQueryAsync(query, cancellationToken).ConfigureAwait(false);
    var scoredMatches = await SimilaritySearchWithVectorCoreAsync(queryVector, k, cancellationToken).ConfigureAwait(false);

    // Keep the document half of each (document, score) pair.
    return from scoredMatch in scoredMatches
           select scoredMatch.Item1;
}
|
||
/// <inheritdoc />
/// <remarks>
/// Copies <paramref name="embedding"/> into an array, runs the vector search and returns only
/// the documents, dropping the scores.
/// </remarks>
public override async Task<IEnumerable<Document>> SimilaritySearchByVectorAsync(
    IEnumerable<float> embedding,
    int k = 4,
    CancellationToken cancellationToken = default)
{
    var queryVector = embedding.ToArray();
    var scoredMatches = await SimilaritySearchWithVectorCoreAsync(queryVector, k, cancellationToken).ConfigureAwait(false);

    // Keep the document half of each (document, score) pair.
    return from scoredMatch in scoredMatches
           select scoredMatch.Item1;
}
|
||
/// <inheritdoc />
/// <remarks>
/// Embeds <paramref name="query"/> and returns the (document, score) pairs produced by the
/// vector search unchanged.
/// </remarks>
public override async Task<IEnumerable<(Document, float)>> SimilaritySearchWithScoreAsync(
    string query,
    int k = 4,
    CancellationToken cancellationToken = default)
{
    var queryVector = await Embeddings.EmbedQueryAsync(query, cancellationToken).ConfigureAwait(false);

    return await SimilaritySearchWithVectorCoreAsync(queryVector, k, cancellationToken).ConfigureAwait(false);
}
|
||
/// <inheritdoc />
/// <remarks>
/// Not supported by this implementation: the returned task is always faulted with a
/// <see cref="NotSupportedException"/>, which is raised when the task is awaited.
/// </remarks>
public override Task<IEnumerable<Document>> MaxMarginalRelevanceSearchByVector(
    IEnumerable<float> embedding,
    int k = 4,
    int fetchK = 20,
    float lambdaMult = 0.5f,
    CancellationToken cancellationToken = default)
    // A faulted task keeps the failure asynchronous without an await-less async method (CS1998).
    => Task.FromException<IEnumerable<Document>>(
        new NotSupportedException("Querying not supported by SemanticKernel impl."));
|
||
/// <inheritdoc />
/// <remarks>
/// Embeds <paramref name="query"/> and delegates to <c>MaxMarginalRelevanceSearchByVector</c>
/// with the remaining arguments passed through unchanged.
/// </remarks>
public override async Task<IEnumerable<Document>> MaxMarginalRelevanceSearch(
    string query,
    int k = 4,
    int fetchK = 20,
    float lambdaMult = 0.5f,
    CancellationToken cancellationToken = default)
{
    var queryVector = await Embeddings.EmbedQueryAsync(query, cancellationToken).ConfigureAwait(false);

    return await MaxMarginalRelevanceSearchByVector(queryVector, k, fetchK, lambdaMult, cancellationToken)
        .ConfigureAwait(false);
}
|
||
/// <inheritdoc />
/// <remarks>
/// Returns <c>OverrideRelevanceScoreFn</c> when it is set. Otherwise returns a function that, on
/// each call, reads the <c>"hnsw:space"</c> entry of the collection metadata (falling back to
/// <c>"l2"</c> when absent) and applies the matching relevance function. That function throws
/// <see cref="ArgumentException"/> for any distance type other than "cosine", "l2" or "ip".
/// </remarks>
protected override Func<float, float> SelectRelevanceScoreFn()
{
    if (OverrideRelevanceScoreFn != null)
    {
        return OverrideRelevanceScoreFn;
    }

    return ScoreByCollectionDistance;

    // Resolved per call, so the metadata is read when a score is computed, not when the function is selected.
    float ScoreByCollectionDistance(float distance)
    {
        const string distanceKey = "hnsw:space";
        const string defaultDistanceType = "l2";

        var distanceType = defaultDistanceType;
        if (CollectionMetadata != null
            && CollectionMetadata.TryGetValue(distanceKey, out var configuredType))
        {
            distanceType = configuredType;
        }

        return distanceType switch
        {
            "cosine" => CosineRelevanceScoreFn(distance),
            "l2" => EuclideanRelevanceScoreFn(distance),
            "ip" => MaxInnerProductRelevanceScoreFn(distance),
            // Single-line message: a verbatim multi-line literal would embed a line break and source indentation.
            _ => throw new ArgumentException(
                $"No supported normalization function for distance metric of type: {distanceType}. " +
                "Consider providing relevance_score_fn to Chroma constructor.")
        };
    }
}
|
||
/// <summary>
/// Embeds the given texts and upserts one memory record per text into the bound collection.
/// </summary>
/// <param name="texts">Texts to embed and store.</param>
/// <param name="metadatas">Metadata per text, serialized to JSON as the record's additional metadata.</param>
/// <param name="ids">Record id per text.</param>
/// <param name="cancellationToken">Token forwarded to the embedding and upsert calls.</param>
/// <returns>The ids yielded by the store's batch upsert; empty when <paramref name="texts"/> is empty.</returns>
/// <exception cref="InvalidOperationException">
/// The embedding provider returned a different number of vectors than there are texts.
/// </exception>
private async Task<IEnumerable<string>> AddCoreAsync(
    string[] texts,
    Dictionary<string, object>[] metadatas,
    string[] ids,
    CancellationToken cancellationToken)
{
    // Nothing to store: skip the embedding and upsert round-trips entirely.
    if (texts.Length == 0)
    {
        return Array.Empty<string>();
    }

    var embeddings = await Embeddings
        .EmbedDocumentsAsync(texts, cancellationToken)
        .ConfigureAwait(false);

    // Report a mismatch explicitly instead of an IndexOutOfRangeException in the loop below.
    if (embeddings.Length != texts.Length)
    {
        throw new InvalidOperationException(
            $"Embedding provider returned {embeddings.Length} vectors for {texts.Length} texts.");
    }

    var records = new MemoryRecord[texts.Length];
    for (var index = 0; index < texts.Length; index++)
    {
        // TODO: check: description, externalSourceName, key
        records[index] = new MemoryRecord
        (
            new MemoryRecordMetadata
            (
                isReference: false,
                id: ids[index],
                text: texts[index],
                description: string.Empty,
                externalSourceName: string.Empty,
                additionalMetadata: SerializeMetadata(metadatas[index])
            ),
            new Embedding<float>(embeddings[index]),
            key: null
        );
    }

    var resultIds = new List<string>(texts.Length);
    var resultIdsIterator = _store.UpsertBatchAsync(_collectionName, records, cancellationToken);
    await foreach (var item in resultIdsIterator.ConfigureAwait(false))
    {
        resultIds.Add(item);
    }

    return resultIds;
}
|
||
/// <summary>
/// Queries the bound collection for the nearest matches to an embedding and converts each match
/// into a document paired with its score.
/// </summary>
/// <param name="embeddings">Query vector.</param>
/// <param name="k">Limit passed to the store's nearest-match query.</param>
/// <param name="cancellationToken">Token forwarded to the store query and to result collection.</param>
/// <returns>
/// A materialized collection of (document, score) pairs; the store's score is narrowed to <c>float</c>.
/// </returns>
private async Task<IEnumerable<(Document, float)>> SimilaritySearchWithVectorCoreAsync(
    float[] embeddings,
    int k,
    CancellationToken cancellationToken)
{
    var matches = await _store
        .GetNearestMatchesAsync(
            _collectionName,
            new Embedding<float>(embeddings),
            k,
            cancellationToken: cancellationToken)
        .ToListAsync(cancellationToken)
        .ConfigureAwait(false);

    // Materialize once so metadata is deserialized a single time and repeated enumeration
    // yields the same Document instances.
    return matches
        .Select(match =>
        {
            var text = match.Item1.Metadata.Text;
            var metadata = DeserializeMetadata(match.Item1.Metadata);
            return (new Document(text, metadata), (float)match.Item2);
        })
        .ToArray();
}
|
||
/// <summary>
/// Serializes a metadata dictionary to JSON using this store's serializer options.
/// </summary>
/// <param name="metadata">Metadata to serialize.</param>
/// <returns>The JSON text produced by <see cref="JsonSerializer"/>.</returns>
private string SerializeMetadata(Dictionary<string, object> metadata)
    => JsonSerializer.Serialize(metadata, _jsonSerializerOptions);
|
||
/// <summary>
/// Deserializes a record's additional metadata JSON into a dictionary using this store's
/// serializer options.
/// </summary>
/// <param name="metadata">Record metadata whose <c>AdditionalMetadata</c> holds the JSON text.</param>
/// <returns>
/// The deserialized dictionary, or an empty dictionary when the JSON text is null, empty,
/// whitespace, or deserializes to <c>null</c>.
/// </returns>
private Dictionary<string, object> DeserializeMetadata(MemoryRecordMetadata metadata)
{
    var json = metadata.AdditionalMetadata;

    // Empty text is not valid JSON and would make the deserializer throw; treat it as "no metadata".
    if (string.IsNullOrWhiteSpace(json))
    {
        return new Dictionary<string, object>();
    }

    // TODO: issue with this method is it returns values as JsonElements instead of primitive types
    // NOTE(review): the ObjectAsPrimitiveConverter in the serializer options presumably addresses this — confirm.
    return JsonSerializer.Deserialize<Dictionary<string, object>>(json, _jsonSerializerOptions)
           ?? new Dictionary<string, object>();
}
} |
Oops, something went wrong.