Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tokenizer Fixes For Issue 430 #433

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions LLama.Unittest/LLamaContextTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ public void Tokenize()
Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
}

[Fact]
public void TokenizeNewline()
{
var tokens = _context.Tokenize("\n");

Assert.Equal(new LLamaToken[] { 1, 29871, 13 }, tokens);
}

[Fact]
public void TokenizeWithoutBOS()
{
Expand Down
10 changes: 9 additions & 1 deletion LLama/Native/LLamaToken.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
using System.Runtime.InteropServices;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace LLama.Native;

/// <summary>
/// A single token
/// </summary>
[StructLayout(LayoutKind.Sequential)]
[DebuggerDisplay("Value")]
public readonly record struct LLamaToken
{
/// <summary>
Expand Down Expand Up @@ -35,4 +37,10 @@ private LLamaToken(int value)
/// <param name="value"></param>
/// <returns></returns>
public static implicit operator LLamaToken(int value) => new(value);

/// <inheritdoc />
public override string ToString()
{
return Value.ToString();
}
}
32 changes: 1 addition & 31 deletions LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,37 +155,7 @@ public Span<float> GetLogitsIth(int i)
/// <exception cref="RuntimeError"></exception>
public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
ThrowIfDisposed();

if (string.IsNullOrEmpty(text) && !add_bos)
return Array.Empty<LLamaToken>();

// Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
// possibly be more than this.
var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);

// "Rent" an array to write results into (avoiding an allocation of a large array)
var temporaryArray = ArrayPool<LLamaToken>.Shared.Rent(count);
try
{
// Do the actual conversion
var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special);
if (n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
"specify the encoding.");
}

// Copy the results from the rented into an array which is exactly the right size
var result = new LLamaToken[n];
Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);

return result;
}
finally
{
ArrayPool<LLamaToken>.Shared.Return(temporaryArray);
}
return ThrowIfDisposed().Tokenize(text, add_bos, special, encoding);
}

/// <summary>
Expand Down
47 changes: 27 additions & 20 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -172,34 +173,40 @@ internal Span<char> TokensToSpan(IReadOnlyList<LLamaToken> tokens, Span<char> de
/// <returns></returns>
public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
// Early exit if there's no work to do
if (text == "" && !add_bos)
return Array.Empty<LLamaToken>();

// Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text);
var bytes = new byte[bytesCount + 1];
unsafe
var bytes = ArrayPool<byte>.Shared.Rent(bytesCount + 1);
try
{
fixed (char* charPtr = text)
fixed (byte* bytePtr = &bytes[0])
unsafe
{
encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
}
}

unsafe
{
fixed (byte* bytesPtr = &bytes[0])
{
// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special);

// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new LLamaToken[count];
fixed (LLamaToken* tokensPtr = &tokens[0])
fixed (char* textPtr = text)
fixed (byte* bytesPtr = bytes)
{
NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
return tokens;
// Convert text into bytes
encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length);

// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special);

// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new LLamaToken[count];
fixed (LLamaToken* tokensPtr = tokens)
{
NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
return tokens;
}
}
}
}
finally
{
ArrayPool<byte>.Shared.Return(bytes, true);
}
}
#endregion

Expand Down
Loading