Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add disposal of NpgsqlDataSource after opening connection #771

Merged
merged 5 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.KernelMemory;
using Microsoft.KernelMemory.MemoryStorage;
Expand Down Expand Up @@ -71,7 +71,7 @@ last_update TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL
var indexName = "create_index_test";
var vectorSize = 1536;

using var target = new PostgresMemory(config, new FakeEmbeddingGenerator());
var target = new PostgresMemory(config, new FakeEmbeddingGenerator());

var tasks = new List<Task>();
for (int i = 0; i < concurrency; i++)
Expand All @@ -98,7 +98,7 @@ public async Task UpsertConcurrencyTest()
var vectorSize = 4;
var indexName = "upsert_test" + Guid.NewGuid().ToString("D");

using var target = new PostgresMemory(this.PostgresConfig, new FakeEmbeddingGenerator());
var target = new PostgresMemory(this.PostgresConfig, new FakeEmbeddingGenerator());

await target.CreateIndexAsync(indexName, vectorSize);

Expand Down
52 changes: 19 additions & 33 deletions extensions/Postgres/Postgres/Internals/PostgresDbClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
Expand All @@ -19,7 +19,7 @@ namespace Microsoft.KernelMemory.Postgres;
/// <summary>
/// An implementation of a client for Postgres. This class is used to managing postgres database operations.
/// </summary>
internal sealed class PostgresDbClient : IDisposable
internal sealed class PostgresDbClient
{
// See: https://www.postgresql.org/docs/current/errcodes-appendix.html
private const string PgErrUndefinedTable = "42P01"; // undefined_table
Expand All @@ -28,7 +28,7 @@ internal sealed class PostgresDbClient : IDisposable
private const string PgErrDatabaseDoesNotExist = "3D000"; // invalid_catalog_name

private readonly ILogger _log;
private readonly NpgsqlDataSource _dataSource;
private readonly NpgsqlDataSourceBuilder _dataSourceBuilder;

marcominerva marked this conversation as resolved.
Show resolved Hide resolved
private readonly string _schema;
private readonly string _tableNamePrefix;
Expand All @@ -52,10 +52,10 @@ public PostgresDbClient(PostgresConfig config, ILoggerFactory? loggerFactory = n
config.Validate();
this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<PostgresDbClient>();

NpgsqlDataSourceBuilder dataSourceBuilder = new(config.ConnectionString);
this._dataSourceBuilder = new(config.ConnectionString);
this._dataSourceBuilder.UseVector();

this._dbNamePresent = config.ConnectionString.Contains("Database=", StringComparison.OrdinalIgnoreCase);
dataSourceBuilder.UseVector();
this._dataSource = dataSourceBuilder.Build();
this._schema = config.Schema;
this._tableNamePrefix = config.TableNamePrefix;

Expand Down Expand Up @@ -228,8 +228,8 @@ public async Task CreateTableAsync(
if (await this.DoesTableExistAsync(origInputTableName, cancellationToken).ConfigureAwait(false))
{
// Check if the custom SQL contains the lock placeholder (assuming it's not commented out)
bool missingLockStatement = (!string.IsNullOrEmpty(this._createTableSql)
&& !this._createTableSql.Contains(PostgresConfig.SqlPlaceholdersLockId, StringComparison.Ordinal));
bool missingLockStatement = !string.IsNullOrEmpty(this._createTableSql)
&& !this._createTableSql.Contains(PostgresConfig.SqlPlaceholdersLockId, StringComparison.Ordinal);

if (missingLockStatement)
{
Expand Down Expand Up @@ -658,24 +658,6 @@ public async Task DeleteAsync(
}
}

/// <inheritdoc/>
public void Dispose()
{
this.Dispose(true);
GC.SuppressFinalize(this);
}

/// <summary>
/// Disposes the managed resources
/// </summary>
private void Dispose(bool disposing)
{
if (disposing)
{
(this._dataSource as IDisposable)?.Dispose();
}
}

/// <summary>
/// Try to connect to PG, handling exceptions in case the DB doesn't exist
/// </summary>
Expand All @@ -685,7 +667,11 @@ private async Task<NpgsqlConnection> ConnectAsync(CancellationToken cancellation
{
try
{
return await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
var dataSource = this._dataSourceBuilder.Build();
await using (dataSource.ConfigureAwait(false))
{
return await dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
}
}
catch (Npgsql.PostgresException e) when (IsDbNotFoundException(e))
{
Expand Down Expand Up @@ -750,20 +736,20 @@ private string WithTableNamePrefix(string tableName)

private static bool IsDbNotFoundException(Npgsql.PostgresException e)
{
return (e.SqlState == PgErrDatabaseDoesNotExist);
return e.SqlState == PgErrDatabaseDoesNotExist;
}

private static bool IsTableNotFoundException(Npgsql.PostgresException e)
{
return (e.SqlState == PgErrUndefinedTable || e.Message.Contains("does not exist", StringComparison.OrdinalIgnoreCase));
return e.SqlState == PgErrUndefinedTable || e.Message.Contains("does not exist", StringComparison.OrdinalIgnoreCase);
}

private static bool IsVectorTypeDoesNotExistException(Npgsql.PostgresException e)
{
return (e.SqlState == PgErrTypeDoesNotExist
&& e.Message.Contains("type", StringComparison.OrdinalIgnoreCase)
&& e.Message.Contains("vector", StringComparison.OrdinalIgnoreCase)
&& e.Message.Contains("does not exist", StringComparison.OrdinalIgnoreCase));
return e.SqlState == PgErrTypeDoesNotExist
&& e.Message.Contains("type", StringComparison.OrdinalIgnoreCase)
marcominerva marked this conversation as resolved.
Show resolved Hide resolved
&& e.Message.Contains("vector", StringComparison.OrdinalIgnoreCase)
&& e.Message.Contains("does not exist", StringComparison.OrdinalIgnoreCase);
}

/// <summary>
Expand Down
20 changes: 1 addition & 19 deletions extensions/Postgres/Postgres/PostgresMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace Microsoft.KernelMemory.Postgres;
/// Postgres connector for Kernel Memory.
/// </summary>
[Experimental("KMEXP03")]
public sealed class PostgresMemory : IMemoryDb, IDisposable
public sealed class PostgresMemory : IMemoryDb
{
private readonly ILogger<PostgresMemory> _log;
private readonly ITextEmbeddingGenerator _embeddingGenerator;
Expand Down Expand Up @@ -209,24 +209,6 @@ public Task DeleteAsync(
return this._db.DeleteAsync(tableName: index, id: record.Id, cancellationToken);
}

/// <inheritdoc/>
public void Dispose()
{
this.Dispose(true);
GC.SuppressFinalize(this);
}

/// <summary>
/// Disposes the managed resources.
/// </summary>
private void Dispose(bool disposing)
{
if (disposing)
{
(this._db as IDisposable)?.Dispose();
}
}

#region private ================================================================================

// Note: "_" is allowed in Postgres, but we normalize it to "-" for consistency with other DBs
Expand Down
Loading