Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLava API Improvements #757

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions LLama/LLavaWeights.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ private LLavaWeights(SafeLlavaModelHandle weights)
{
NativeHandle = weights;
}


#region load
/// <summary>
/// Load weights into memory
/// </summary>
Expand All @@ -43,7 +44,9 @@ public static Task<LLavaWeights> LoadFromFileAsync(string mmProject, Cancellatio
{
return Task.Run(() => LoadFromFile(mmProject), token);
}
#endregion

#region embed
/// <summary>
/// Create the Image Embeddings from the bytes of an image.
/// </summary>
Expand All @@ -57,9 +60,20 @@ public static Task<LLavaWeights> LoadFromFileAsync(string mmProject, Cancellatio
/// </list>
/// </param>
/// <returns></returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image )
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
{
return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
}

/// <summary>
/// Create the Image Embeddings.
/// </summary>
/// <param name="image">Image in binary format (it supports jpeg format only)</param>
/// <param name="threads">Number of threads to use</param>
/// <returns>return the SafeHandle of these embeddings</returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
{
return NativeHandle.CreateImageEmbeddings(ctxLlama, image );
return NativeHandle.CreateImageEmbeddings(image, threads);
}

/// <summary>
Expand All @@ -76,10 +90,30 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, by
/// </param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image )
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image)
{
return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
}

/// <summary>
/// Create the Image Embeddings from the bytes of an image.
/// </summary>
/// <param name="image">Path to the image file. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
/// <item>PNG</item>
/// <item>BMP</item>
/// <item>TGA</item>
/// </list>
/// </param>
/// <param name="threads"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1)
{
return NativeHandle.CreateImageEmbeddings(ctxLlama, image );
return NativeHandle.CreateImageEmbeddings(image, threads);
}
#endregion

/// <summary>
/// Eval the image embeddings
Expand Down
3 changes: 2 additions & 1 deletion LLama/Native/LLavaImageEmbed.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ namespace LLama.Native;
/// <summary>
/// LLaVa Image embeddings
/// </summary>
/// <remarks>llava_image_embed</remarks>
[StructLayout(LayoutKind.Sequential)]
unsafe public struct LLavaImageEmbed
public unsafe struct LLavaImageEmbed
{
public float* embed;
public int n_image_pos;
Expand Down
105 changes: 95 additions & 10 deletions LLama/Native/SafeLlavaImageEmbedHandle.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.IO;


Expand All @@ -10,11 +10,39 @@ namespace LLama.Native
public sealed class SafeLlavaImageEmbedHandle
: SafeLLamaHandleBase
{
/// <summary>
/// Get the model used to create this image embedding
/// </summary>
public SafeLlavaModelHandle Model { get; private set; } = null!;

#region embed
/// <summary>
/// Create an image embed from an image file
/// </summary>
/// <param name="clip"></param>
/// <param name="ctx"></param>
/// <param name="image">Path to the image file. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
/// <item>PNG</item>
/// <item>BMP</item>
/// <item>TGA</item>
/// </list>
/// </param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, LLamaContext ctx, string image)
{
if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");

return CreateFromFileName(clip, image, (int)ctx.BatchThreads);
}

/// <summary>
/// Create an image embed from an image file
/// </summary>
/// <param name="ctxLlava"></param>
/// <param name="ctxLlama"></param>
/// <param name="clip"></param>
/// <param name="image">Path to the image file. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
Expand All @@ -23,25 +51,32 @@ public sealed class SafeLlavaImageEmbedHandle
/// <item>TGA</item>
/// </list>
/// </param>
/// <param name="threads"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle ctxLlava, LLamaContext ctxLlama, string image )
public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, string image, int threads = -1)
{
if (threads <= 0)
threads = Environment.ProcessorCount / 2;

// Try to open the image file, this will check:
// - File exists (automatically throws FileNotFoundException)
// - File is readable (explicit check)
// This provides better error messages that llama.cpp, which would throw an access violation exception in both cases.
using (var fs = new FileStream(image, FileMode.Open))
if (!fs.CanRead)
throw new InvalidOperationException($"Llava image file '{image}' is not readable");
return NativeApi.llava_image_embed_make_with_filename(ctxLlava, (int) ctxLlama.BatchThreads, image);

var embed = NativeApi.llava_image_embed_make_with_filename(clip, threads, image);
embed.Model = clip;
return embed;
}

/// <summary>
/// Create an image embed from the bytes of an image.
/// </summary>
/// <param name="ctxLlava"></param>
/// <param name="ctxLlama"></param>
/// <param name="clip"></param>
/// <param name="ctx"></param>
/// <param name="image">Image bytes. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
Expand All @@ -51,17 +86,67 @@ public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle
/// </list>
/// </param>
/// <returns></returns>
public static SafeLlavaImageEmbedHandle CreateFromMemory( SafeLlavaModelHandle ctxLlava, LLamaContext ctxLlama, byte[] image )
public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, LLamaContext ctx, byte[] image)
{
return NativeApi.llava_image_embed_make_with_bytes(ctxLlava, (int) ctxLlama.BatchThreads, image, image.Length);
if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");

return CreateFromMemory(clip, image, (int)ctx.BatchThreads);
}

/// <summary>
/// Create an image embed from the bytes of an image.
/// </summary>
/// <param name="clip"></param>
/// <param name="image">Image bytes. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
/// <item>PNG</item>
/// <item>BMP</item>
/// <item>TGA</item>
/// </list>
/// </param>
/// <param name="threads"></param>
/// <returns></returns>
public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, byte[] image, int threads = -1)
{
if (threads <= 0)
threads = Environment.ProcessorCount / 2;

var embed = NativeApi.llava_image_embed_make_with_bytes(clip, threads, image, image.Length);
embed.Model = clip;
return embed;
}
#endregion

/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llava_image_embed_free(DangerousGetHandle());
SetHandle(IntPtr.Zero);
return true;
}

/// <summary>
/// Copy the embeddings data to the destination span
/// </summary>
/// <param name="dest"></param>
/// <param name="index"></param>
public void GetEmbedding(Span<float> dest, int index)
{
if (index < 0)
throw new ArgumentOutOfRangeException(nameof(index), "index must be >= 0");
if (index >= Model.PatchCount)
throw new ArgumentOutOfRangeException(nameof(index), "index must be < Model.PatchCount");

unsafe
{
var embed = (LLavaImageEmbed*)DangerousGetHandle();
new Span<float>(
embed->embed + Model.EmbeddingDimensions * index,
Model.EmbeddingDimensions
).CopyTo(dest);
}
}
}
}
45 changes: 41 additions & 4 deletions LLama/Native/SafeLlavaModelHandle.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.IO;
using System.Runtime.InteropServices;
using LLama.Exceptions;
Expand All @@ -12,6 +12,16 @@ namespace LLama.Native
public sealed class SafeLlavaModelHandle
: SafeLLamaHandleBase
{
/// <summary>
/// Get the number of dimensions in an embedding
/// </summary>
public int EmbeddingDimensions => clip_n_mmproj_embd(this);

/// <summary>
/// Get the number of "patches" in an image embedding
/// </summary>
public int PatchCount => clip_n_patches(this);

/// <inheritdoc />
protected override bool ReleaseHandle()
{
Expand All @@ -30,7 +40,6 @@ protected override bool ReleaseHandle()
/// <exception cref="RuntimeError"></exception>
public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity )
{

// Try to open the model file, this will check:
// - File exists (automatically throws FileNotFoundException)
// - File is readable (explicit check)
Expand All @@ -57,16 +66,38 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, st
return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image);
}

/// <summary>
/// Create the Image Embeddings.
/// </summary>
/// <param name="image">Image in binary format (it supports jpeg format only)</param>
/// <param name="threads">Number of threads to use</param>
/// <returns>return the SafeHandle of these embeddings</returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1)
{
return SafeLlavaImageEmbedHandle.CreateFromFileName(this, image, threads);
}

/// <summary>
/// Create the Image Embeddings.
/// </summary>
/// <param name="ctxLlama">LLama Context</param>
/// <param name="image">Image in binary format (it supports jpeg format only)</param>
/// <returns>return the SafeHandle of these embeddings</returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image )
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
{
return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image );
}

/// <summary>
/// Create the Image Embeddings.
/// </summary>
/// <param name="image">Image in binary format (it supports jpeg format only)</param>
/// <param name="threads">Number of threads to use</param>
/// <returns>return the SafeHandle of these embeddings</returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
{
return SafeLlavaImageEmbedHandle.CreateFromMemory(this, image, threads);
}

/// <summary>
/// Evaluates the image embeddings.
Expand All @@ -79,7 +110,8 @@ public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imag
{
return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.Params.BatchSize, ref n_past );
}


#region native API
/// <summary>
/// Load MULTI MODAL PROJECTIONS model / Clip Model
/// </summary>
Expand All @@ -96,6 +128,11 @@ public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imag
[DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)]
private static extern void clip_free(IntPtr ctx);

[DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int clip_n_mmproj_embd(SafeLlavaModelHandle ctx);

[DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int clip_n_patches(SafeLlavaModelHandle ctx);
#endregion
}
}
Loading