diff --git a/LLama/LLavaWeights.cs b/LLama/LLavaWeights.cs index 9594dcdbb..cb9692ead 100644 --- a/LLama/LLavaWeights.cs +++ b/LLama/LLavaWeights.cs @@ -21,7 +21,8 @@ private LLavaWeights(SafeLlavaModelHandle weights) { NativeHandle = weights; } - + + #region load /// /// Load weights into memory /// @@ -43,7 +44,9 @@ public static Task LoadFromFileAsync(string mmProject, Cancellatio { return Task.Run(() => LoadFromFile(mmProject), token); } + #endregion + #region embed /// /// Create the Image Embeddings from the bytes of an image. /// @@ -57,9 +60,20 @@ public static Task LoadFromFileAsync(string mmProject, Cancellatio /// /// /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image ) + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image) + { + return NativeHandle.CreateImageEmbeddings(ctxLlama, image); + } + + /// + /// Create the Image Embeddings from the bytes of an image, without requiring a LLamaContext. + /// + /// Image bytes (JPG, PNG, BMP or TGA per SafeLlavaImageEmbedHandle.CreateFromMemory) + /// Number of threads to use; values <= 0 fall back to half the processor count + /// return the SafeHandle of these embeddings + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1) { - return NativeHandle.CreateImageEmbeddings(ctxLlama, image ); + return NativeHandle.CreateImageEmbeddings(image, threads); } /// @@ -76,10 +90,30 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, by /// /// /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image ) + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image) + { + return NativeHandle.CreateImageEmbeddings(ctxLlama, image); + } + + /// + /// Create the Image Embeddings from an image file, without requiring a LLamaContext. + /// + /// Path to the image file. 
Supported formats: + /// + /// JPG + /// PNG + /// BMP + /// TGA + /// + /// + /// + /// + /// + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1) { - return NativeHandle.CreateImageEmbeddings(ctxLlama, image ); + return NativeHandle.CreateImageEmbeddings(image, threads); } + #endregion /// /// Eval the image embeddings diff --git a/LLama/Native/LLavaImageEmbed.cs b/LLama/Native/LLavaImageEmbed.cs index 2030515ec..7704b73de 100644 --- a/LLama/Native/LLavaImageEmbed.cs +++ b/LLama/Native/LLavaImageEmbed.cs @@ -5,8 +5,9 @@ namespace LLama.Native; /// /// LLaVa Image embeddings /// +/// llava_image_embed [StructLayout(LayoutKind.Sequential)] -unsafe public struct LLavaImageEmbed +public unsafe struct LLavaImageEmbed { public float* embed; public int n_image_pos; diff --git a/LLama/Native/SafeLlavaImageEmbedHandle.cs b/LLama/Native/SafeLlavaImageEmbedHandle.cs index aa6da9e0e..77b4eaf66 100644 --- a/LLama/Native/SafeLlavaImageEmbedHandle.cs +++ b/LLama/Native/SafeLlavaImageEmbedHandle.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.IO; @@ -10,11 +10,39 @@ namespace LLama.Native public sealed class SafeLlavaImageEmbedHandle : SafeLLamaHandleBase { + /// + /// Get the model used to create this image embedding (assigned by the CreateFromFileName / CreateFromMemory factory methods) + /// + public SafeLlavaModelHandle Model { get; private set; } = null!; + + #region embed + /// + /// Create an image embed from an image file, after checking that the projector and model embedding sizes match; encoding uses ctx.BatchThreads threads + /// + /// + /// + /// Path to the image file. Supported formats: + /// + /// JPG + /// PNG + /// BMP + /// TGA + /// + /// + /// + /// + public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, LLamaContext ctx, string image) + { + if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip)) + throw new InvalidOperationException($"Cannot create image embed. 
Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})"); + + return CreateFromFileName(clip, image, (int)ctx.BatchThreads); + } + /// /// Create an image embed from an image file /// - /// - /// + /// /// Path to the image file. Supported formats: /// /// JPG @@ -23,10 +51,14 @@ public sealed class SafeLlavaImageEmbedHandle /// TGA /// /// + /// /// /// - public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle ctxLlava, LLamaContext ctxLlama, string image ) + public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, string image, int threads = -1) { + if (threads <= 0) + threads = Math.Max(1, Environment.ProcessorCount / 2); // ProcessorCount / 2 is 0 on a single-core host; never pass 0 threads + // Try to open the image file, this will check: // - File exists (automatically throws FileNotFoundException) // - File is readable (explicit check) @@ -34,14 +66,17 @@ public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle using (var fs = new FileStream(image, FileMode.Open)) if (!fs.CanRead) throw new InvalidOperationException($"Llava image file '{image}' is not readable"); - return NativeApi.llava_image_embed_make_with_filename(ctxLlava, (int) ctxLlama.BatchThreads, image); + + var embed = NativeApi.llava_image_embed_make_with_filename(clip, threads, image); + embed.Model = clip; + return embed; } - + /// /// Create an image embed from the bytes of an image. /// - /// - /// + /// + /// /// Image bytes. 
Supported formats: /// /// JPG @@ -51,11 +86,39 @@ public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle /// /// /// - public static SafeLlavaImageEmbedHandle CreateFromMemory( SafeLlavaModelHandle ctxLlava, LLamaContext ctxLlama, byte[] image ) + public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, LLamaContext ctx, byte[] image) { - return NativeApi.llava_image_embed_make_with_bytes(ctxLlava, (int) ctxLlama.BatchThreads, image, image.Length); + if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip)) + throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})"); + + return CreateFromMemory(clip, image, (int)ctx.BatchThreads); } + /// + /// Create an image embed from the bytes of an image. + /// + /// + /// Image bytes. Supported formats: + /// + /// JPG + /// PNG + /// BMP + /// TGA + /// + /// + /// + /// + public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, byte[] image, int threads = -1) + { + if (threads <= 0) + threads = Math.Max(1, Environment.ProcessorCount / 2); // ProcessorCount / 2 is 0 on a single-core host; never pass 0 threads + + var embed = NativeApi.llava_image_embed_make_with_bytes(clip, threads, image, image.Length); + embed.Model = clip; + return embed; + } + #endregion + /// protected override bool ReleaseHandle() { @@ -63,5 +126,27 @@ protected override bool ReleaseHandle() SetHandle(IntPtr.Zero); return true; } + + /// + /// Copy the embedding of one patch (Model.EmbeddingDimensions floats starting at patch "index") to the destination span + /// + /// + /// + public void GetEmbedding(Span<float> dest, int index) + { + if (index < 0) + throw new ArgumentOutOfRangeException(nameof(index), "index must be >= 0"); + if (index >= Model.PatchCount) + throw new ArgumentOutOfRangeException(nameof(index), "index must be < Model.PatchCount"); + + unsafe + { + var embed = (LLavaImageEmbed*)DangerousGetHandle(); + new Span<float>( + embed->embed + Model.EmbeddingDimensions * index, + 
Model.EmbeddingDimensions + ).CopyTo(dest); + } + } } } diff --git a/LLama/Native/SafeLlavaModelHandle.cs b/LLama/Native/SafeLlavaModelHandle.cs index fd898b536..9bc1ec8d2 100644 --- a/LLama/Native/SafeLlavaModelHandle.cs +++ b/LLama/Native/SafeLlavaModelHandle.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.IO; using System.Runtime.InteropServices; using LLama.Exceptions; @@ -12,6 +12,16 @@ namespace LLama.Native public sealed class SafeLlavaModelHandle : SafeLLamaHandleBase { + /// + /// Get the number of dimensions in an embedding + /// + public int EmbeddingDimensions => clip_n_mmproj_embd(this); + + /// + /// Get the number of "patches" in an image embedding + /// + public int PatchCount => clip_n_patches(this); + /// protected override bool ReleaseHandle() { @@ -30,7 +40,6 @@ protected override bool ReleaseHandle() /// public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity ) { - // Try to open the model file, this will check: // - File exists (automatically throws FileNotFoundException) // - File is readable (explicit check) @@ -57,16 +66,38 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, st return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image); } + /// + /// Create the Image Embeddings from an image file. + /// + /// Path to the image file (JPG, PNG, BMP or TGA per SafeLlavaImageEmbedHandle.CreateFromFileName) + /// Number of threads to use; values <= 0 fall back to half the processor count + /// return the SafeHandle of these embeddings + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1) + { + return SafeLlavaImageEmbedHandle.CreateFromFileName(this, image, threads); + } + /// /// Create the Image Embeddings. 
/// /// LLama Context /// Image in binary format (it supports jpeg format only) /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image ) + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image) { return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image ); } + + /// + /// Create the Image Embeddings from the bytes of an image. + /// + /// Image bytes (JPG, PNG, BMP or TGA per SafeLlavaImageEmbedHandle.CreateFromMemory) + /// Number of threads to use; values <= 0 fall back to half the processor count + /// return the SafeHandle of these embeddings + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1) + { + return SafeLlavaImageEmbedHandle.CreateFromMemory(this, image, threads); + } /// /// Evaluates the image embeddings. @@ -79,7 +110,8 @@ public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imag { return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.Params.BatchSize, ref n_past ); } - + + #region native API /// /// Load MULTI MODAL PROJECTIONS model / Clip Model /// @@ -96,6 +128,11 @@ public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imag [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)] private static extern void clip_free(IntPtr ctx); + [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int clip_n_mmproj_embd(SafeLlavaModelHandle ctx); + [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int clip_n_patches(SafeLlavaModelHandle ctx); + #endregion } }