Skip to content

Commit

Permalink
Adding IVectorDatabase implementation for mongo
Browse files Browse the repository at this point in the history
  • Loading branch information
vikhyat90 committed May 30, 2024
1 parent b2aa602 commit 2cf5e63
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 3 deletions.
20 changes: 20 additions & 0 deletions src/Databases/IntegrationTests/DatabaseTests.Configure.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using LangChain.Databases.OpenSearch;
using LangChain.Databases.Postgres;
using LangChain.Databases.Sqlite;
using LangChain.Databases.Mongo;
using Testcontainers.MongoDb;
using Testcontainers.PostgreSql;

namespace LangChain.Databases.IntegrationTests;
Expand Down Expand Up @@ -101,6 +103,24 @@ private static async Task<DatabaseTestEnvironment> StartEnvironmentForAsync(Supp
Port = port2,
};
}

case SupportedDatabase.Mongo:
{
var port = Random.Shared.Next(49152, 65535);
var container = new MongoDbBuilder()
.WithImage("mongo")
.WithPortBinding(hostPort: port, containerPort: 27017)
.WithWaitStrategy(Wait.ForUnixContainer().UntilPortIsAvailable(27017))
.Build();

await container.StartAsync(cancellationToken);

return new DatabaseTestEnvironment
{
VectorDatabase = new MongoVectorDatabase(container.GetConnectionString()),
Container = container,
};
}
default:
throw new ArgumentOutOfRangeException(nameof(database), database, null);
}
Expand Down
7 changes: 7 additions & 0 deletions src/Databases/IntegrationTests/DatabaseTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public partial class DatabaseTests
[TestCase(SupportedDatabase.OpenSearch)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
public async Task CreateAndDeleteCollection_Ok(SupportedDatabase database)
{
await using var environment = await StartEnvironmentForAsync(database);
Expand Down Expand Up @@ -56,6 +57,7 @@ await vectorDatabase.Invoking(y => y.GetCollectionAsync(environment.CollectionNa
[TestCase(SupportedDatabase.OpenSearch)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
public async Task AddDocuments_Ok(SupportedDatabase database)
{
await using var environment = await StartEnvironmentForAsync(database);
Expand Down Expand Up @@ -100,6 +102,7 @@ public async Task AddDocuments_Ok(SupportedDatabase database)
[TestCase(SupportedDatabase.OpenSearch)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
public async Task AddTexts_Ok(SupportedDatabase database)
{
await using var environment = await StartEnvironmentForAsync(database);
Expand Down Expand Up @@ -148,6 +151,7 @@ public async Task AddTexts_Ok(SupportedDatabase database)
[TestCase(SupportedDatabase.OpenSearch)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
public async Task DeleteDocuments_Ok(SupportedDatabase database)
{
await using var environment = await StartEnvironmentForAsync(database);
Expand Down Expand Up @@ -185,6 +189,7 @@ public async Task DeleteDocuments_Ok(SupportedDatabase database)
[TestCase(SupportedDatabase.OpenSearch)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
public async Task SimilaritySearch_Ok(SupportedDatabase database)
{
await using var environment = await StartEnvironmentForAsync(database);
Expand Down Expand Up @@ -216,6 +221,7 @@ public async Task SimilaritySearch_Ok(SupportedDatabase database)
[TestCase(SupportedDatabase.OpenSearch)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
public async Task SimilaritySearchByVector_Ok(SupportedDatabase database)
{
await using var environment = await StartEnvironmentForAsync(database);
Expand Down Expand Up @@ -243,6 +249,7 @@ public async Task SimilaritySearchByVector_Ok(SupportedDatabase database)
[TestCase(SupportedDatabase.OpenSearch)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
public async Task SimilaritySearchWithScores_Ok(SupportedDatabase database)
{
await using var environment = await StartEnvironmentForAsync(database);
Expand Down
2 changes: 2 additions & 0 deletions src/Databases/Mongo/src/Client/IMongoContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ namespace LangChain.Databases.Mongo.Client;
public interface IMongoContext
{
IMongoCollection<T> GetCollection<T>(string name);
Task<List<string>> GetCollections();
IMongoDatabase GetDatabase();
}
9 changes: 8 additions & 1 deletion src/Databases/Mongo/src/Client/IMongoDbClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LangChain.Databases.Mongo.Model;
using MongoDB.Driver;
using System.Linq.Expressions;

namespace LangChain.Databases.Mongo.Client;
Expand All @@ -7,7 +8,13 @@ public interface IMongoDbClient
{
Task BatchDeactivate<T>(Expression<Func<T, bool>> filter)
where T : BaseEntity;

bool CollectionExists(string collectionName);
Task<bool> CollectionExistsAsync(string collectionName);
Task<IMongoCollection<T>> CreateCollection<T>(string collectionName);
Task DropCollectionAsync(string collectionName);
Task<IEnumerable<TProjected>> Get<T, TProjected>(Expression<Func<T, bool>> filter, Expression<Func<T, TProjected>> projectionExpression) where T : BaseEntity;
IMongoCollection<T> GetCollection<T>();
Task<List<string>> GetCollections();
IEnumerable<TProjected> GetSync<T, TProjected>(
Expression<Func<T, bool>> filter,
Expression<Func<T, TProjected>> projectionExpression)
Expand Down
27 changes: 25 additions & 2 deletions src/Databases/Mongo/src/Client/MongoContext.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using MongoDB.Driver;
using MongoDB.Bson;
using MongoDB.Driver;
using MongoDB.Driver.Core.Authentication;
using MongoDB.Driver.Core.Configuration;

namespace LangChain.Databases.Mongo.Client;

public class MongoContext : IMongoContext
{
private readonly IMongoDatabase _mongoDatabase;
public readonly IMongoDatabase _mongoDatabase;

Check warning on line 10 in src/Databases/Mongo/src/Client/MongoContext.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Identifier '_mongoDatabase' is not CLS-compliant

public MongoContext(IDatabaseConfiguration databaseConfiguration)
{
Expand All @@ -19,5 +22,25 @@ public IMongoCollection<T> GetCollection<T>(string name)
name = name ?? throw new ArgumentNullException(nameof(name));

return _mongoDatabase.GetCollection<T>(name);
}

public IMongoDatabase GetDatabase()
{
return _mongoDatabase;
}

public async Task<List<string>> GetCollections()
{
List<string> collectionNames = new List<string>();
var collections = await _mongoDatabase.ListCollectionsAsync();

foreach (BsonDocument collection in await collections.ToListAsync<BsonDocument>())
{
string name = collection["name"].AsString;
collectionNames.Add(name);
}

return collectionNames;
}

}
42 changes: 42 additions & 0 deletions src/Databases/Mongo/src/Client/MongoDbClient.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
using LangChain.Databases.Mongo.Model;
using MongoDB.Bson;
using MongoDB.Driver;
using MongoDB.Driver.Core.Configuration;
using System.Linq.Expressions;
using System.Xml.Linq;

namespace LangChain.Databases.Mongo.Client;

public class MongoDbClient(IMongoContext mongoContext) : IMongoDbClient
{

public async Task BatchDeactivate<T>(Expression<Func<T, bool>> filter) where T : BaseEntity
{
var entityIds = (await Get(filter, p => p.Id).ConfigureAwait(false)).ToList();
Expand Down Expand Up @@ -53,4 +57,42 @@ public async Task InsertAsync<T>(T entity) where T : BaseEntity

await GetCollection<T>().InsertOneAsync(entity).ConfigureAwait(false);
}

public async Task<bool> CollectionExistsAsync(string collectionName)
{
var filter = new BsonDocument("name", collectionName);
var options = new ListCollectionNamesOptions { Filter = filter };

var collections = await mongoContext.GetDatabase().ListCollectionNamesAsync(options).ConfigureAwait(false);

return await collections.AnyAsync();
}

public bool CollectionExists(string collectionName)
{
var filter = new BsonDocument("name", collectionName);
var options = new ListCollectionNamesOptions { Filter = filter };
return mongoContext.GetDatabase().ListCollectionNames(options).Any();
}

public async Task<List<string>> GetCollections()
{
return await mongoContext.GetCollections();
}

public async Task<IMongoCollection<T>> CreateCollection<T>(string collectionName)
{
await mongoContext.GetDatabase().CreateCollectionAsync(collectionName, new CreateCollectionOptions
{
AutoIndexId = true

Check warning on line 87 in src/Databases/Mongo/src/Client/MongoDbClient.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

'CreateCollectionOptions.AutoIndexId' is obsolete: 'AutoIndexId has been deprecated since server version 3.2.'
}).ConfigureAwait(false);

var collection = mongoContext.GetCollection<T>(collectionName);
return collection;
}

public async Task DropCollectionAsync(string collectionName)
{
await mongoContext.GetDatabase().DropCollectionAsync(collectionName);
}
}
86 changes: 86 additions & 0 deletions src/Databases/Mongo/src/MongoVectorDatabase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using LangChain.Databases.Mongo.Client;
using MongoDB.Driver;

namespace LangChain.Databases.Mongo
{
public class MongoVectorDatabase(
string connectionString,
string schema = MongoVectorDatabase.DefaultSchema)
: IVectorDatabase
{
private const string DefaultSchema = "langchain";

private readonly IMongoDbClient _client = new MongoDbClient(
new MongoContext(
new DatabaseConfiguration
{
ConnectionString = connectionString,
DatabaseName = schema,
}));

/// <inheritdoc />
public async Task<IVectorCollection> GetCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
{
if (!await IsCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false))
{
throw new InvalidOperationException($"Collection '{collectionName}' does not exist.");
}

var context = new MongoContext(new DatabaseConfiguration
{
ConnectionString = connectionString,
DatabaseName = schema,
});

return new MongoVectorCollection(context, "idx_"+collectionName, name: collectionName);
}

/// <inheritdoc />
public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
{
await _client.DropCollectionAsync(collectionName).ConfigureAwait(false);
}

/// <inheritdoc />
public async Task<IVectorCollection> GetOrCreateCollectionAsync(string collectionName, int dimensions, CancellationToken cancellationToken = default)
{
if (!await IsCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false))
{
await CreateCollectionAsync(collectionName, dimensions, cancellationToken).ConfigureAwait(false);
}

return await GetCollectionAsync(collectionName, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
public async Task<bool> IsCollectionExistsAsync(string collectionName, CancellationToken cancellationToken = default)
{
return await _client.CollectionExistsAsync(collectionName).ConfigureAwait(false);
}

/// <inheritdoc />
public async Task CreateCollectionAsync(string collectionName, int dimensions, CancellationToken cancellationToken = default)
{
var collection = await _client.CreateCollection<Vector>(collectionName).ConfigureAwait(false);
var indexName = await collection.Indexes.CreateOneAsync(new CreateIndexModel<Vector>(
Builders<Vector>.IndexKeys.Ascending(v => v.Embedding)
.Ascending(v => v.Text), new CreateIndexOptions
{
Background = true,
}), cancellationToken: cancellationToken).ConfigureAwait(false);
return;
}

/// <inheritdoc />
public async Task<IReadOnlyList<string>> ListCollectionsAsync(CancellationToken cancellationToken = default)
{
return await _client.GetCollections().ConfigureAwait(false);
}
}
}

0 comments on commit 2cf5e63

Please sign in to comment.