Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added ImageToText abstractions and HuggingFace implementation. #152

Merged
merged 5 commits into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions examples/LangChain.Samples.HuggingFace/Program.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@
using LangChain.Providers.HuggingFace;
using LangChain.Providers;
using LangChain.Providers.HuggingFace;
using LangChain.Providers.HuggingFace.Predefined;

// Sample program: queries two Hugging Face hosted models through the provider abstractions
// and prints their responses to the console.
using var client = new HttpClient();
var provider = new HuggingFaceProvider(apiKey: string.Empty, client);

// Text generation with the predefined GPT-2 model. The prompt is sent exactly once.
var gpt2Model = new Gpt2Model(provider);
var gpt2Response = await gpt2Model.GenerateAsync("What would be a good company name be for name a company that makes colorful socks?");

Console.WriteLine("### GP2 Response");
Console.WriteLine(gpt2Response);

// Image-to-text generation with the model identified below.
const string imageToTextModel = "Salesforce/blip-image-captioning-base";
var model = new HuggingFaceImageToTextModel(provider, imageToTextModel);

// The sample reads solar_system.png from the temp directory; the file must already exist there.
var path = Path.Combine(Path.GetTempPath(), "solar_system.png");
var imageData = await File.ReadAllBytesAsync(path);

// The media type must describe the actual payload: the file is a PNG, so "image/png".
var binaryData = new BinaryData(imageData, "image/png");

var imageToTextResponse = await model.GenerateTextFromImageAsync(new ImageToTextRequest
{
    Image = binaryData
});

Console.WriteLine("\n\n### ImageToText Response");
Console.WriteLine(imageToTextResponse.Text);

// Keep the console window open until the user presses Enter.
Console.ReadLine();
16 changes: 16 additions & 0 deletions src/Core/src/Chains/Chain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using LangChain.Chains.StackableChains.Agents.Crew;
using LangChain.Chains.StackableChains.Files;
using LangChain.Chains.StackableChains.ImageGeneration;
using LangChain.Chains.StackableChains.ImageToTextGeneration;
using LangChain.Chains.StackableChains.ReAct;
using LangChain.Indexes;
using LangChain.Memory;
Expand Down Expand Up @@ -298,4 +299,19 @@ public static ExtractCodeChain ExtractCode(
{
return new ExtractCodeChain(inputKey, outputKey);
}

/// <summary>
/// Creates an <see cref="ImageToTextGenerationChain"/> from the supplied model and image.
/// </summary>
/// <param name="model">Image-to-text model passed to the chain.</param>
/// <param name="image">Image data passed to the chain.</param>
/// <param name="outputKey">Output key passed to the chain; defaults to <c>"text"</c>.</param>
/// <returns>The newly constructed chain.</returns>
public static ImageToTextGenerationChain GenerateImageToText(
    IImageToTextModel model,
    BinaryData image,
    string outputKey = "text")
    => new(model, image, outputKey);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using LangChain.Abstractions.Schema;
using LangChain.Chains.HelperChains;
using LangChain.Providers;

namespace LangChain.Chains.StackableChains.ImageToTextGeneration;

/// <summary>
/// Stackable chain that sends a fixed image to an <see cref="IImageToTextModel"/>
/// and stores the model's response in the chain values under a single output key.
/// </summary>
public class ImageToTextGenerationChain : BaseStackableChain
{
    private readonly IImageToTextModel _model;
    private readonly BinaryData _image;

    /// <summary>
    /// Creates a chain bound to one model and one image.
    /// </summary>
    /// <param name="model">Model used to generate text from the image.</param>
    /// <param name="image">Image sent to the model on every call.</param>
    /// <param name="outputKey">Key under which the result is stored; defaults to <c>"text"</c>.</param>
    /// <exception cref="ArgumentNullException">Thrown when any argument is <see langword="null"/>.</exception>
    public ImageToTextGenerationChain(
        IImageToTextModel model,
        BinaryData image,
        string outputKey = "text")
    {
        // Fail fast here instead of with a NullReferenceException when the chain is run.
        _model = model ?? throw new ArgumentNullException(nameof(model));
        _image = image ?? throw new ArgumentNullException(nameof(image));
        OutputKeys = new[] { outputKey ?? throw new ArgumentNullException(nameof(outputKey)) };
    }

    /// <inheritdoc />
    protected override async Task<IChainValues> InternalCall(IChainValues values)
    {
        values = values ?? throw new ArgumentNullException(nameof(values));

        var request = new ImageToTextRequest { Image = _image };
        var response = await _model.GenerateTextFromImageAsync(request).ConfigureAwait(false);

        // NOTE(review): the whole response object is stored here, not response.Text —
        // confirm that downstream chains reading this key expect the response type.
        values.Value[OutputKeys[0]] = response;
        return values;
    }
}
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageVersion>
<PackageVersion Include="StackExchange.Redis" Version="2.7.20" />
<PackageVersion Include="System.Memory.Data" Version="8.0.0" />
<PackageVersion Include="System.Net.Http" Version="4.3.4" />
<PackageVersion Include="System.Text.Json" Version="8.0.0" />
<PackageVersion Include="System.ValueTuple" Version="4.5.0" />
Expand Down
3 changes: 3 additions & 0 deletions src/Providers/Abstractions/src/Common/Provider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,7 @@ public abstract class Provider(string id) : Model(id), IProvider

/// <inheritdoc />
public TextToSpeechSettings? TextToSpeechSettings { get; init; }

/// <inheritdoc />
public ImageToTextSettings? ImageToTextSettings { get; init; }
}
19 changes: 19 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/IImageToTextModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
namespace LangChain.Providers;

/// <summary>
/// Defines a large language model that can be used for image to text generation.
/// </summary>
public interface IImageToTextModel : IModel<ImageToTextSettings>

Check warning on line 6 in src/Providers/Abstractions/src/ImageToText/IImageToTextModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'IImageToTextModel' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Run the LLM on the image carried by <paramref name="request"/> and produce text for it.
/// </summary>
/// <param name="request">Request holding the image to process.</param>
/// <param name="settings">
/// Optional per-call settings; <see langword="null"/> by default. How they are combined with
/// model and provider settings is up to the implementation
/// (presumably via <see cref="ImageToTextSettings.Calculate"/> — confirm).
/// </param>
/// <param name="cancellationToken">Token that can be used to cancel the operation.</param>
/// <returns>A task producing the <see cref="ImageToTextResponse"/> for the request.</returns>
public Task<ImageToTextResponse> GenerateTextFromImageAsync(
ImageToTextRequest request,
ImageToTextSettings? settings = null,
CancellationToken cancellationToken = default);
}
19 changes: 19 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/IImageToTextModel`2.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
namespace LangChain.Providers;

/// <summary>
/// Defines a large language model that can be used for image to text generation.
/// </summary>
public interface IImageToTextModel<in TRequest, TResponse, in TSettings> : IImageToTextModel

Check warning on line 6 in src/Providers/Abstractions/src/ImageToText/IImageToTextModel`2.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'IImageToTextModel<TRequest, TResponse, TSettings>' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Run the LLM on the image, using the strongly typed request, response and settings
/// of this interface.
/// </summary>
/// <param name="request">Request of type <typeparamref name="TRequest"/> describing the image to process.</param>
/// <param name="settings">Optional per-call settings of type <typeparamref name="TSettings"/>; <see langword="default"/> when omitted.</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation.</param>
/// <returns>A task producing the <typeparamref name="TResponse"/> for the request.</returns>
public Task<TResponse> GenerateTextFromImageAsync(
TRequest request,
TSettings? settings = default,
CancellationToken cancellationToken = default);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using System.Text.Json.Serialization;

namespace LangChain.Providers;

public class ImageToTextGenerationResponse : List<ImageToTextGenerationResponse.GeneratedTextItem>

Check warning on line 5 in src/Providers/Abstractions/src/ImageToText/ImageToTextGenerationResponse.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'ImageToTextGenerationResponse' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)

Check warning on line 5 in src/Providers/Abstractions/src/ImageToText/ImageToTextGenerationResponse.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'implicit constructor for 'ImageToTextGenerationResponse'' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Single entry of the response list; carries one generated text value.
/// </summary>
public sealed class GeneratedTextItem
{
/// <summary>
/// The generated text, bound to the <c>generated_text</c> JSON property.
/// May be <see langword="null"/> when the property is absent.
/// </summary>
[JsonPropertyName("generated_text")]
public string? GeneratedText { get; set; }
}
}
10 changes: 10 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/ImageToTextModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers;

public abstract class ImageToTextModel(string id) : Model<ImageToTextSettings>(id), IImageToTextModel<ImageToTextRequest, ImageToTextResponse, ImageToTextSettings>

Check warning on line 4 in src/Providers/Abstractions/src/ImageToText/ImageToTextModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'ImageToTextModel' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Generates text from the image in <paramref name="request"/>; concrete models provide the implementation.
/// </summary>
/// <param name="request">Request holding the image to process.</param>
/// <param name="settings">Optional per-call settings; <see langword="default"/> when omitted.</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation.</param>
/// <returns>A task producing the <see cref="ImageToTextResponse"/> for the request.</returns>
public abstract Task<ImageToTextResponse> GenerateTextFromImageAsync(
ImageToTextRequest request,
ImageToTextSettings? settings = default,
CancellationToken cancellationToken = default);
}
13 changes: 13 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/ImageToTextRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers;

/// <summary>
/// Base class for image to text requests.
/// </summary>
public class ImageToTextRequest

Check warning on line 7 in src/Providers/Abstractions/src/ImageToText/ImageToTextRequest.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'ImageToTextRequest' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)

Check warning on line 7 in src/Providers/Abstractions/src/ImageToText/ImageToTextRequest.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'implicit constructor for 'ImageToTextRequest'' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Image to upload. Declared <c>required</c>, so it must be assigned in the object initializer.
/// </summary>
public required BinaryData Image { get; init; }
}
27 changes: 27 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/ImageToTextResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// ReSharper disable once CheckNamespace
// ReSharper disable ConditionalAccessQualifierIsNonNullableAccordingToAPIContract
namespace LangChain.Providers;

#pragma warning disable CA2225

/// <summary>
///
/// </summary>
public class ImageToTextResponse

Check warning on line 10 in src/Providers/Abstractions/src/ImageToText/ImageToTextResponse.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'ImageToTextResponse' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)

Check warning on line 10 in src/Providers/Abstractions/src/ImageToText/ImageToTextResponse.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Symbol 'implicit constructor for 'ImageToTextResponse'' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Settings associated with this response; must be assigned when the response is created.
/// Presumably the effective settings produced by <see cref="ImageToTextSettings.Calculate"/> — confirm against the model implementations.
/// </summary>
public required ImageToTextSettings UsedSettings { get; init; }

/// <summary>
/// Usage information attached to this response; defaults to <see cref="Usage.Empty"/> when not set.
/// </summary>
public Usage Usage { get; init; } = Usage.Empty;


/// <summary>
/// Generated text; <see langword="null"/> until a value is assigned.
/// </summary>
public string? Text { get; set; }
}
55 changes: 55 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/ImageToTextSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers;

/// <summary>
/// Base class for image to text request settings.
/// </summary>
public class ImageToTextSettings
{
    /// <summary>
    /// Fallback values used by <see cref="Calculate"/> when no other level supplies a setting.
    /// </summary>
    public static ImageToTextSettings Default { get; } = new()
    {
        User = string.Empty,
        Endpoint = "https://api-inference.huggingface.co/models/",
    };

    /// <summary>
    /// Unique user identifier.
    /// </summary>
    public string? User { get; init; }

    /// <summary>
    /// Endpoint url for api. <see langword="null"/> means "not set at this level",
    /// which lets <see cref="Calculate"/> fall through to the next level.
    /// </summary>
    // Declared nullable: the property has no initializer, and Calculate relies on it being
    // null to detect an unset value (previously reported as CS8618).
    public string? Endpoint { get; set; }

    /// <summary>
    /// Calculate the settings to use for the request. Each value is taken from the first
    /// level that provides a non-null one, in the order request, model, provider,
    /// <see cref="Default"/>.
    /// </summary>
    /// <param name="requestSettings">Settings supplied with the individual request; may be <see langword="null"/>.</param>
    /// <param name="modelSettings">Settings configured on the model; may be <see langword="null"/>.</param>
    /// <param name="providerSettings">Settings configured on the provider; may be <see langword="null"/>.</param>
    /// <returns>A new settings instance with every value resolved.</returns>
    /// <exception cref="InvalidOperationException">No level, including <see cref="Default"/>, supplies a value.</exception>
    public static ImageToTextSettings Calculate(
        ImageToTextSettings? requestSettings,
        ImageToTextSettings? modelSettings,
        ImageToTextSettings? providerSettings)
    {
        return new ImageToTextSettings
        {
            User =
                requestSettings?.User ??
                modelSettings?.User ??
                providerSettings?.User ??
                Default.User ??
                throw new InvalidOperationException("Default User is not set."),
            Endpoint =
                requestSettings?.Endpoint ??
                modelSettings?.Endpoint ??
                providerSettings?.Endpoint ??
                Default.Endpoint ??
                throw new InvalidOperationException("Default Endpoint is not set."),
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="System.Memory.Data" />
</ItemGroup>

</Project>
Loading
Loading