mirror of https://github.com/google/gemma.cpp.git
175 lines
5.7 KiB
C#
175 lines
5.7 KiB
C#
using System;
|
|
using System.Diagnostics;
|
|
using System.Runtime.InteropServices;
|
|
using System.Text;
|
|
namespace GemmaCpp
{
    /// <summary>
    /// Raised when the native gemma library reports a failure: a context could not be
    /// created, generation returned an error code, or a method was called on an invalid context.
    /// </summary>
    public class GemmaException : Exception
    {
        /// <summary>Initializes a new instance with no message.</summary>
        public GemmaException() { }

        /// <summary>Initializes a new instance with the given message.</summary>
        /// <param name="message">Description of the failure.</param>
        public GemmaException(string message) : base(message) { }

        /// <summary>Initializes a new instance that wraps an underlying cause.</summary>
        /// <param name="message">Description of the failure.</param>
        /// <param name="innerException">The exception that caused this one.</param>
        public GemmaException(string message, Exception innerException)
            : base(message, innerException) { }
    }

    /// <summary>
    /// Managed wrapper around a native gemma context created through the <c>gemma</c>
    /// library's C entry points (GemmaCreate / GemmaGenerate / GemmaCountTokens / GemmaDestroy).
    /// </summary>
    /// <remarks>
    /// Each instance owns one unmanaged context. Call <see cref="Dispose"/> to release it
    /// deterministically; the finalizer is only a safety net.
    /// </remarks>
    public class Gemma : IDisposable
    {
        private IntPtr _context;
        private bool _disposed;

        // Only allocated if the optional log callback in the constructor is enabled.
        private GCHandle _logCallbackHandle;

        // The native library is preloaded lazily, on first construction, rather than in a
        // static constructor. This lets callers set DllPath first, and a failed load can be
        // retried instead of poisoning the type with TypeInitializationException.
        private static readonly object s_loadLock = new object();
        private static bool s_libraryLoaded;

        /// <summary>
        /// Path of the native library to preload (via LoadLibrary, Windows only) before the
        /// first context is created. Set it before constructing the first <see cref="Gemma"/>;
        /// once a load has succeeded, later changes have no effect.
        /// </summary>
        public static string DllPath { get; set; } = "gemma.dll";

        [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern IntPtr LoadLibrary(string lpFileName);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr GemmaCreate(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
            int maxLength);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern void GemmaDestroy(IntPtr context);

        /// <summary>
        /// Receives generated text during <see cref="Generate(string, TokenCallback, int)"/>.
        /// The return value is forwarded to the native library (presumably a continue/stop
        /// signal; verify against the native API).
        /// </summary>
        public delegate bool TokenCallback(string token);

        // NOTE(review): default bool marshaling is a 4-byte Win32 BOOL. If the native callback
        // type returns a C/C++ `bool` (1 byte), consider [return: MarshalAs(UnmanagedType.I1)].
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate bool GemmaTokenCallback(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string text,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern int GemmaGenerate(
            IntPtr context,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
            [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output,
            int maxLength,
            GemmaTokenCallback callback,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern int GemmaCountTokens(
            IntPtr context,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string text);

        // Native log callback delegate type (used only by the optional snippet in the constructor).
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate void GemmaLogCallback(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string message,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern void GemmaSetLogCallback(
            IntPtr context,
            GemmaLogCallback callback,
            IntPtr userData);

        /// <summary>
        /// Preloads <see cref="DllPath"/> at most once per process (Windows only).
        /// </summary>
        /// <remarks>
        /// The [DllImport("gemma")] declarations name only "gemma". Presumably they bind to the
        /// module already loaded here, which is what makes a custom <see cref="DllPath"/> take effect.
        /// </remarks>
        /// <exception cref="DllNotFoundException">LoadLibrary failed for <see cref="DllPath"/>.</exception>
        private static void EnsureLibraryLoaded()
        {
            lock (s_loadLock)
            {
                if (s_libraryLoaded)
                    return;

                // LoadLibrary comes from kernel32, which exists only on Windows. On other
                // platforms, rely on the runtime's default probing for the "gemma" library.
                if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
                {
                    string path = DllPath;
                    if (LoadLibrary(path) == IntPtr.Zero)
                    {
                        // Flag stays false so a later construction (e.g. after fixing DllPath) retries.
                        throw new DllNotFoundException($"Failed to load {path}. Error: {Marshal.GetLastWin32Error()}");
                    }
                }

                s_libraryLoaded = true;
            }
        }

        /// <summary>
        /// Loads the native library if needed and creates a native gemma context.
        /// </summary>
        /// <param name="tokenizerPath">Tokenizer path, forwarded to GemmaCreate.</param>
        /// <param name="modelType">Model type string, forwarded to GemmaCreate.</param>
        /// <param name="weightsPath">Weights path, forwarded to GemmaCreate.</param>
        /// <param name="weightType">Weight type string, forwarded to GemmaCreate.</param>
        /// <param name="maxLength">Length limit forwarded to GemmaCreate.</param>
        /// <exception cref="DllNotFoundException">The native library could not be loaded.</exception>
        /// <exception cref="GemmaException">GemmaCreate returned a null context.</exception>
        public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
        {
            EnsureLibraryLoaded();

            _context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
            if (_context == IntPtr.Zero)
            {
                throw new GemmaException("Failed to create Gemma context");
            }

            // Optional native logging: uncomment to enable. The GCHandle keeps the delegate alive
            // while the context exists (native code presumably stores the pointer). It is freed in
            // ReleaseNativeResources, after the context is destroyed.
            /*
            GemmaLogCallback logCallback = (message, _) =>
            {
#if UNITY_5_3_OR_NEWER
                UnityEngine.Debug.Log($"Gemma: {message}");
#else
                Debug.WriteLine($"Gemma: {message}");
#endif
            };
            _logCallbackHandle = GCHandle.Alloc(logCallback);
            GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
            */
        }

        /// <summary>Returns the token count the native library reports for <paramref name="prompt"/>.</summary>
        /// <param name="prompt">Text to tokenize; must not be null.</param>
        /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
        /// <exception cref="GemmaException">The native context is invalid.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="prompt"/> is null.</exception>
        public int CountTokens(string prompt)
        {
            ThrowIfUnusable();
            if (prompt == null)
                throw new ArgumentNullException(nameof(prompt));

            return GemmaCountTokens(_context, prompt);
        }

        /// <summary>Generates text for <paramref name="prompt"/> without a streaming callback.</summary>
        /// <param name="prompt">Input text; must not be null.</param>
        /// <param name="maxLength">Output buffer capacity and limit passed to GemmaGenerate.</param>
        /// <returns>The text written by the native side into the output buffer.</returns>
        public string Generate(string prompt, int maxLength = 4096)
        {
            return Generate(prompt, null, maxLength);
        }

        /// <summary>
        /// Generates text for <paramref name="prompt"/>, optionally streaming it to <paramref name="callback"/>.
        /// </summary>
        /// <param name="prompt">Input text; must not be null.</param>
        /// <param name="callback">Invoked by native code with generated text during the call; may be null.</param>
        /// <param name="maxLength">Output buffer capacity and limit passed to GemmaGenerate.</param>
        /// <returns>The text written by the native side into the output buffer.</returns>
        /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
        /// <exception cref="GemmaException">The context is invalid or GemmaGenerate returned a negative value.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="prompt"/> is null.</exception>
        public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
        {
            ThrowIfUnusable();
            if (prompt == null)
                throw new ArgumentNullException(nameof(prompt));

            var output = new StringBuilder(maxLength);

            // A per-call local, not an instance field, so nested or concurrent calls on the
            // same instance cannot overwrite or free each other's callback.
            GemmaTokenCallback nativeCallback = null;
            if (callback != null)
            {
                nativeCallback = (text, _) => callback(text);
            }

            int length = GemmaGenerate(_context, prompt, output, maxLength, nativeCallback, IntPtr.Zero);

            // Assumes native code does not retain the callback after GemmaGenerate returns.
            // KeepAlive stops the GC from collecting the delegate before the call finishes.
            GC.KeepAlive(nativeCallback);

            if (length < 0)
                throw new GemmaException("Generation failed");

            return output.ToString();
        }

        /// <summary>Releases the native context and any log-callback handle.</summary>
        public void Dispose()
        {
            ReleaseNativeResources();
            // Everything the finalizer would do is done; skip finalization.
            GC.SuppressFinalize(this);
        }

        /// <summary>Safety net for callers that never call <see cref="Dispose"/>.</summary>
        ~Gemma()
        {
            ReleaseNativeResources();
        }

        // Throws if the instance can no longer be used for native calls.
        private void ThrowIfUnusable()
        {
            if (_disposed)
                throw new ObjectDisposedException(nameof(Gemma));

            if (_context == IntPtr.Zero)
                throw new GemmaException("Gemma context is invalid");
        }

        // Shared by Dispose and the finalizer; idempotent thanks to the _disposed flag.
        private void ReleaseNativeResources()
        {
            if (_disposed)
                return;

            if (_context != IntPtr.Zero)
            {
                GemmaDestroy(_context);
                _context = IntPtr.Zero;
            }

            // Freed after GemmaDestroy so the native side presumably can no longer call the delegate.
            if (_logCallbackHandle.IsAllocated)
                _logCallbackHandle.Free();

            _disposed = true;
        }
    }
}