Skip to content

Commit

Permalink
- Added new ways to create an image embedding which don't require a c…
Browse files Browse the repository at this point in the history
…ontext

 - Added a way to get the embedding data directly from an image embed (one embedding vector at a time)
  • Loading branch information
martindevans committed May 26, 2024
1 parent 0017912 commit b56a50a
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 20 deletions.
44 changes: 39 additions & 5 deletions LLama/LLavaWeights.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ private LLavaWeights(SafeLlavaModelHandle weights)
{
NativeHandle = weights;
}


#region load
/// <summary>
/// Load weights into memory
/// </summary>
Expand All @@ -43,7 +44,9 @@ public static Task<LLavaWeights> LoadFromFileAsync(string mmProject, Cancellatio
{
return Task.Run(() => LoadFromFile(mmProject), token);
}
#endregion

#region embed
/// <summary>
/// Create the Image Embeddings from the bytes of an image.
/// </summary>
Expand All @@ -57,9 +60,20 @@ public static Task<LLavaWeights> LoadFromFileAsync(string mmProject, Cancellatio
/// </list>
/// </param>
/// <returns></returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image )
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
{
return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
}

/// <summary>
/// Create the Image Embeddings.
/// </summary>
/// <param name="image">Image in binary format (it supports jpeg format only)</param>
/// <param name="threads">Number of threads to use</param>
/// <returns>return the SafeHandle of these embeddings</returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
{
return NativeHandle.CreateImageEmbeddings(ctxLlama, image );
return NativeHandle.CreateImageEmbeddings(image, threads);
}

/// <summary>
Expand All @@ -76,10 +90,30 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, by
/// </param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image )
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image)
{
return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
}

/// <summary>
/// Create the Image Embeddings from the bytes of an image.
/// </summary>
/// <param name="image">Path to the image file. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
/// <item>PNG</item>
/// <item>BMP</item>
/// <item>TGA</item>
/// </list>
/// </param>
/// <param name="threads"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1)
{
return NativeHandle.CreateImageEmbeddings(ctxLlama, image );
return NativeHandle.CreateImageEmbeddings(image, threads);
}
#endregion

/// <summary>
/// Eval the image embeddings
Expand Down
3 changes: 2 additions & 1 deletion LLama/Native/LLavaImageEmbed.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ namespace LLama.Native;
/// <summary>
/// LLaVa Image embeddings
/// </summary>
/// <remarks>llava_image_embed</remarks>
[StructLayout(LayoutKind.Sequential)]
unsafe public struct LLavaImageEmbed
public unsafe struct LLavaImageEmbed
{
public float* embed;
public int n_image_pos;
Expand Down
105 changes: 95 additions & 10 deletions LLama/Native/SafeLlavaImageEmbedHandle.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.IO;


Expand All @@ -10,11 +10,39 @@ namespace LLama.Native
public sealed class SafeLlavaImageEmbedHandle
: SafeLLamaHandleBase
{
/// <summary>
/// Get the model used to create this image embedding
/// </summary>
public SafeLlavaModelHandle Model { get; private set; } = null!;

#region embed
/// <summary>
/// Create an image embed from an image file
/// </summary>
/// <param name="clip"></param>
/// <param name="ctx"></param>
/// <param name="image">Path to the image file. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
/// <item>PNG</item>
/// <item>BMP</item>
/// <item>TGA</item>
/// </list>
/// </param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, LLamaContext ctx, string image)
{
if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");

return CreateFromFileName(clip, image, (int)ctx.BatchThreads);
}

/// <summary>
/// Create an image embed from an image file
/// </summary>
/// <param name="ctxLlava"></param>
/// <param name="ctxLlama"></param>
/// <param name="clip"></param>
/// <param name="image">Path to the image file. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
Expand All @@ -23,25 +51,32 @@ public sealed class SafeLlavaImageEmbedHandle
/// <item>TGA</item>
/// </list>
/// </param>
/// <param name="threads"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle ctxLlava, LLamaContext ctxLlama, string image )
public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, string image, int threads = -1)
{
if (threads <= 0)
threads = Environment.ProcessorCount / 2;

// Try to open the image file, this will check:
// - File exists (automatically throws FileNotFoundException)
// - File is readable (explicit check)
// This provides better error messages that llama.cpp, which would throw an access violation exception in both cases.
using (var fs = new FileStream(image, FileMode.Open))
if (!fs.CanRead)
throw new InvalidOperationException($"Llava image file '{image}' is not readable");
return NativeApi.llava_image_embed_make_with_filename(ctxLlava, (int) ctxLlama.BatchThreads, image);

var embed = NativeApi.llava_image_embed_make_with_filename(clip, threads, image);
embed.Model = clip;
return embed;
}

/// <summary>
/// Create an image embed from the bytes of an image.
/// </summary>
/// <param name="ctxLlava"></param>
/// <param name="ctxLlama"></param>
/// <param name="clip"></param>
/// <param name="ctx"></param>
/// <param name="image">Image bytes. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
Expand All @@ -51,17 +86,67 @@ public static SafeLlavaImageEmbedHandle CreateFromFileName( SafeLlavaModelHandle
/// </list>
/// </param>
/// <returns></returns>
public static SafeLlavaImageEmbedHandle CreateFromMemory( SafeLlavaModelHandle ctxLlava, LLamaContext ctxLlama, byte[] image )
public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, LLamaContext ctx, byte[] image)
{
return NativeApi.llava_image_embed_make_with_bytes(ctxLlava, (int) ctxLlama.BatchThreads, image, image.Length);
if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");

return CreateFromMemory(clip, image, (int)ctx.BatchThreads);
}

/// <summary>
/// Create an image embed from the bytes of an image.
/// </summary>
/// <param name="clip"></param>
/// <param name="image">Image bytes. Supported formats:
/// <list type="bullet">
/// <item>JPG</item>
/// <item>PNG</item>
/// <item>BMP</item>
/// <item>TGA</item>
/// </list>
/// </param>
/// <param name="threads"></param>
/// <returns></returns>
public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, byte[] image, int threads = -1)
{
if (threads <= 0)
threads = Environment.ProcessorCount / 2;

var embed = NativeApi.llava_image_embed_make_with_bytes(clip, threads, image, image.Length);
embed.Model = clip;
return embed;
}
#endregion

/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llava_image_embed_free(DangerousGetHandle());
SetHandle(IntPtr.Zero);
return true;
}

/// <summary>
/// Copy the embeddings data to the destination span
/// </summary>
/// <param name="dest"></param>
/// <param name="index"></param>
public void GetEmbedding(Span<float> dest, int index)
{
if (index < 0)
throw new ArgumentOutOfRangeException(nameof(index), "index must be >= 0");
if (index >= Model.PatchCount)
throw new ArgumentOutOfRangeException(nameof(index), "index must be < Model.PatchCount");

unsafe
{
var embed = (LLavaImageEmbed*)DangerousGetHandle();
new Span<float>(
embed->embed + Model.EmbeddingDimensions * index,
Model.EmbeddingDimensions
).CopyTo(dest);
}
}
}
}
45 changes: 41 additions & 4 deletions LLama/Native/SafeLlavaModelHandle.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.IO;
using System.Runtime.InteropServices;
using LLama.Exceptions;
Expand All @@ -12,6 +12,16 @@ namespace LLama.Native
public sealed class SafeLlavaModelHandle
: SafeLLamaHandleBase
{
/// <summary>
/// Get the number of dimensions in an embedding
/// </summary>
public int EmbeddingDimensions => clip_n_mmproj_embd(this);

/// <summary>
/// Get the number of "patches" in an image embedding
/// </summary>
public int PatchCount => clip_n_patches(this);

/// <inheritdoc />
protected override bool ReleaseHandle()
{
Expand All @@ -30,7 +40,6 @@ protected override bool ReleaseHandle()
/// <exception cref="RuntimeError"></exception>
public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity )
{

// Try to open the model file, this will check:
// - File exists (automatically throws FileNotFoundException)
// - File is readable (explicit check)
Expand All @@ -57,16 +66,38 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, st
return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image);
}

/// <summary>
/// Create the Image Embeddings.
/// </summary>
/// <param name="image">Image in binary format (it supports jpeg format only)</param>
/// <param name="threads">Number of threads to use</param>
/// <returns>return the SafeHandle of these embeddings</returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1)
{
return SafeLlavaImageEmbedHandle.CreateFromFileName(this, image, threads);
}

/// <summary>
/// Create the Image Embeddings.
/// </summary>
/// <param name="ctxLlama">LLama Context</param>
/// <param name="image">Image in binary format (it supports jpeg format only)</param>
/// <returns>return the SafeHandle of these embeddings</returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image )
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
{
return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image );
}

/// <summary>
/// Create the Image Embeddings.
/// </summary>
/// <param name="image">Image in binary format (it supports jpeg format only)</param>
/// <param name="threads">Number of threads to use</param>
/// <returns>return the SafeHandle of these embeddings</returns>
public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
{
return SafeLlavaImageEmbedHandle.CreateFromMemory(this, image, threads);
}

/// <summary>
/// Evaluates the image embeddings.
Expand All @@ -79,7 +110,8 @@ public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imag
{
return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.Params.BatchSize, ref n_past );
}


#region native API
/// <summary>
/// Load MULTI MODAL PROJECTIONS model / Clip Model
/// </summary>
Expand All @@ -96,6 +128,11 @@ public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imag
[DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)]
private static extern void clip_free(IntPtr ctx);

[DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int clip_n_mmproj_embd(SafeLlavaModelHandle ctx);

[DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int clip_n_patches(SafeLlavaModelHandle ctx);
#endregion
}
}

0 comments on commit b56a50a

Please sign in to comment.