// Mirror of https://github.com/google/gemma.cpp.git (C#, 474 lines, 18 KiB)
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System;
|
|
using System.Diagnostics;
|
|
using System.Runtime.InteropServices;
|
|
using System.Text;
|
|
namespace GemmaCpp
|
|
{
|
|
/// <summary>
/// Exception thrown by the <c>Gemma</c> wrapper when a native call reports
/// failure or the native context is unusable.
/// </summary>
public class GemmaException : Exception
{
    /// <summary>Creates an exception with the default message.</summary>
    public GemmaException() { }

    /// <summary>Creates an exception carrying <paramref name="message"/>.</summary>
    /// <param name="message">Human-readable description of the failure.</param>
    public GemmaException(string message) : base(message) { }

    /// <summary>
    /// Creates an exception carrying <paramref name="message"/> and the
    /// exception that caused it.
    /// </summary>
    /// <param name="message">Human-readable description of the failure.</param>
    /// <param name="innerException">The underlying cause, or null.</param>
    public GemmaException(string message, Exception innerException)
        : base(message, innerException) { }
}
|
|
|
|
/// <summary>
/// Managed wrapper around a native gemma context. Each instance owns one
/// context created by <c>GemmaCreate</c> and releases it with
/// <c>GemmaDestroy</c> when disposed or finalized.
/// </summary>
public class Gemma : IDisposable
{
    // Serializes the one-time native library load.
    private static readonly object s_loadLock = new object();
    private static bool s_libraryLoaded;

    private IntPtr _context;
    private bool _disposed;

    // Keeps the delegate registered through GemmaSetLogCallback reachable for as
    // long as native code may still invoke it.
    private GCHandle _logCallbackHandle;
    private bool _loggingEnabled = false;

    /// <summary>
    /// Path handed to <c>LoadLibrary</c> the first time a <see cref="Gemma"/>
    /// instance is constructed. Set it before creating the first instance;
    /// changes made after a successful load have no effect.
    /// </summary>
    public static string DllPath { get; set; } = "gemma.dll";

    [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
    private static extern IntPtr LoadLibrary(string lpFileName);

    /// <summary>
    /// Loads the native library from <see cref="DllPath"/> once per process.
    /// </summary>
    /// <remarks>
    /// This is done lazily rather than in a static constructor: a static
    /// constructor runs on the first access to any static member, including the
    /// <see cref="DllPath"/> setter, so the path could never be customised, and a
    /// failure there would surface as a TypeInitializationException that makes the
    /// type permanently unusable.
    /// </remarks>
    /// <exception cref="DllNotFoundException">The library could not be loaded.</exception>
    private static void EnsureNativeLibraryLoaded()
    {
        lock (s_loadLock)
        {
            if (s_libraryLoaded)
            {
                return;
            }

            string path = DllPath;
            if (LoadLibrary(path) == IntPtr.Zero)
            {
                throw new DllNotFoundException($"Failed to load {path}. Error: {Marshal.GetLastWin32Error()}");
            }

            s_libraryLoaded = true;
        }
    }

    // NOTE(review): declared with the three arguments the public constructor
    // supplies; confirm this matches the GemmaCreate export of the native
    // library being shipped.
    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern IntPtr GemmaCreate(
        [MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
        int maxGeneratedTokens);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaDestroy(IntPtr context);

    /// <summary>
    /// Callback invoked with each piece of generated text. Its return value is
    /// handed back to the native generator unchanged.
    /// </summary>
    public delegate bool TokenCallback(string token);

    // NOTE(review): the bool return uses default marshalling; if the native
    // callback type returns a 1-byte C/C++ bool, consider
    // [return: MarshalAs(UnmanagedType.I1)] -- confirm against the C header.
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    private delegate bool GemmaTokenCallback(
        [MarshalAs(UnmanagedType.LPUTF8Str)] string text,
        IntPtr userData);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern int GemmaGenerate(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
        [Out] byte[] output,
        int maxOutputChars,
        GemmaTokenCallback callback,
        IntPtr userData);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern int GemmaGenerateMultimodal(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
        IntPtr image_data,   // Named to match the C API
        int image_width,
        int image_height,
        [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output,
        int maxOutputChars,
        GemmaTokenCallback callback,
        IntPtr userData);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern int GemmaCountTokens(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string text);

    // Configuration function imports
    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaSetMaxGeneratedTokens(IntPtr context, int value);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaSetMultiturn(IntPtr context, int value);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaSetTemperature(IntPtr context, float value);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaSetTopK(IntPtr context, int value);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaSetDeterministic(IntPtr context, int value);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaSetPrefillTbatchSize(IntPtr context, int value);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaResetConversation")]
    private static extern void GemmaResetConversation(IntPtr context);

    // Conversation management function imports
    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaCreateConversation")]
    private static extern int GemmaCreateConversation(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSwitchConversation")]
    private static extern int GemmaSwitchConversation(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaDeleteConversation")]
    private static extern int GemmaDeleteConversation(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaHasConversation")]
    private static extern int GemmaHasConversation(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);

    // The native function returns a const char*. It is declared as IntPtr, not
    // string: with a string return type the marshaller would release the
    // returned buffer with CoTaskMemFree, which must not happen to memory the
    // native library still owns.
    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaGetCurrentConversation")]
    private static extern IntPtr GemmaGetCurrentConversation(IntPtr context);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSaveConversation")]
    private static extern void GemmaSaveConversation(IntPtr context);

    // Native log callback delegate type
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    private delegate void GemmaLogCallback(
        [MarshalAs(UnmanagedType.LPUTF8Str)] string message,
        IntPtr userData);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaSetLogCallback(
        IntPtr context,
        GemmaLogCallback callback,
        IntPtr userData);

    /// <summary>
    /// Loads the native library if necessary and creates a native context.
    /// </summary>
    /// <param name="tokenizerPath">Tokenizer path, passed to the native library as UTF-8.</param>
    /// <param name="weightsPath">Weights path, passed to the native library as UTF-8.</param>
    /// <param name="maxGeneratedTokens">Forwarded unchanged to <c>GemmaCreate</c>.</param>
    /// <exception cref="DllNotFoundException">The native library could not be loaded.</exception>
    /// <exception cref="GemmaException"><c>GemmaCreate</c> returned a null context.</exception>
    public Gemma(string tokenizerPath, string weightsPath, int maxGeneratedTokens = 8192)
    {
        EnsureNativeLibraryLoaded();

        _context = GemmaCreate(tokenizerPath, weightsPath, maxGeneratedTokens);
        if (_context == IntPtr.Zero)
        {
            throw new GemmaException("Failed to create Gemma context");
        }
    }

    /// <summary>
    /// Shared guard for every public operation that touches the native context.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="GemmaException">The native context handle is null.</exception>
    private void ThrowIfUnusable()
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(nameof(Gemma));
        }

        if (_context == IntPtr.Zero)
        {
            throw new GemmaException("Gemma context is invalid");
        }
    }

    /// <summary>
    /// Copies a NUL-terminated UTF-8 string out of native memory without taking
    /// ownership of (or freeing) the native buffer.
    /// </summary>
    /// <returns>The decoded string, or null when <paramref name="ptr"/> is null.</returns>
    private static string PtrToStringUtf8(IntPtr ptr)
    {
        if (ptr == IntPtr.Zero)
        {
            return null;
        }

        int byteCount = 0;
        while (Marshal.ReadByte(ptr, byteCount) != 0)
        {
            byteCount++;
        }

        if (byteCount == 0)
        {
            return string.Empty;
        }

        byte[] bytes = new byte[byteCount];
        Marshal.Copy(ptr, bytes, 0, byteCount);
        return Encoding.UTF8.GetString(bytes);
    }

    /// <summary>
    /// Turns forwarding of native log messages to <see cref="Debug"/> on or off.
    /// Calls that do not change the current state are no-ops.
    /// </summary>
    /// <param name="enable">True to register the log callback, false to remove it.</param>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="GemmaException">The native context handle is null.</exception>
    public void EnableLogging(bool enable = true)
    {
        ThrowIfUnusable();

        if (enable && !_loggingEnabled)
        {
            GemmaLogCallback logCallback = (message, _) =>
            {
                Debug.WriteLine($"Gemma: {message}");
            };
            _logCallbackHandle = GCHandle.Alloc(logCallback);
            GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
            _loggingEnabled = true;
        }
        else if (!enable && _loggingEnabled)
        {
            // Unregister before releasing the handle so native code never holds
            // a pointer to a delegate that is eligible for collection.
            GemmaSetLogCallback(_context, null, IntPtr.Zero);
            if (_logCallbackHandle.IsAllocated)
            {
                _logCallbackHandle.Free();
            }
            _loggingEnabled = false;
        }
    }

    // Configuration methods

    /// <summary>Forwards the multiturn flag to the native context (1 = enabled, 0 = disabled).</summary>
    public void SetMultiturn(bool enable)
    {
        ThrowIfUnusable();

        GemmaSetMultiturn(_context, enable ? 1 : 0);
        Debug.WriteLine($"Gemma: Set multiturn to {(enable ? "enabled" : "disabled")}");
    }

    /// <summary>Forwards the sampling temperature to the native context.</summary>
    public void SetTemperature(float temperature)
    {
        ThrowIfUnusable();

        GemmaSetTemperature(_context, temperature);
        Debug.WriteLine($"Gemma: Set temperature to {temperature}");
    }

    /// <summary>Forwards the top-K value to the native context.</summary>
    public void SetTopK(int topK)
    {
        ThrowIfUnusable();

        GemmaSetTopK(_context, topK);
        Debug.WriteLine($"Gemma: Set topK to {topK}");
    }

    /// <summary>Forwards the deterministic flag to the native context (1 = true, 0 = false).</summary>
    public void SetDeterministic(bool deterministic)
    {
        ThrowIfUnusable();

        GemmaSetDeterministic(_context, deterministic ? 1 : 0);
        Debug.WriteLine($"Gemma: Set deterministic to {(deterministic ? "true" : "false")}");
    }

    /// <summary>Calls <c>GemmaResetConversation</c> on the native context.</summary>
    public void ResetConversation()
    {
        ThrowIfUnusable();

        GemmaResetConversation(_context);
        Debug.WriteLine("Gemma: Reset active conversation");
    }

    // Conversation management methods

    /// <summary>Asks the native library to create a named conversation.</summary>
    /// <returns>True when the native call returns a non-zero value.</returns>
    public bool CreateConversation(string conversationName)
    {
        ThrowIfUnusable();

        bool result = GemmaCreateConversation(_context, conversationName) != 0;
        Debug.WriteLine($"Gemma: Create conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
        return result;
    }

    /// <summary>Asks the native library to switch to a named conversation.</summary>
    /// <returns>True when the native call returns a non-zero value.</returns>
    public bool SwitchConversation(string conversationName)
    {
        ThrowIfUnusable();

        bool result = GemmaSwitchConversation(_context, conversationName) != 0;
        Debug.WriteLine($"Gemma: Switch to conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
        return result;
    }

    /// <summary>Asks the native library to delete a named conversation.</summary>
    /// <returns>True when the native call returns a non-zero value.</returns>
    public bool DeleteConversation(string conversationName)
    {
        ThrowIfUnusable();

        bool result = GemmaDeleteConversation(_context, conversationName) != 0;
        Debug.WriteLine($"Gemma: Delete conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
        return result;
    }

    /// <summary>Asks the native library whether a named conversation exists.</summary>
    /// <returns>True when the native call returns a non-zero value.</returns>
    public bool HasConversation(string conversationName)
    {
        ThrowIfUnusable();

        bool result = GemmaHasConversation(_context, conversationName) != 0;
        Debug.WriteLine($"Gemma: Has conversation '{conversationName}' - {result}");
        return result;
    }

    /// <summary>Returns the conversation name reported by the native library.</summary>
    /// <returns>The name decoded as UTF-8, or null if the native call returned a null pointer.</returns>
    public string GetCurrentConversation()
    {
        ThrowIfUnusable();

        // Copy the text out of the native buffer; the buffer itself is left untouched.
        string currentConversation = PtrToStringUtf8(GemmaGetCurrentConversation(_context));
        Debug.WriteLine($"Gemma: Current conversation is '{currentConversation}'");
        return currentConversation;
    }

    /// <summary>Calls <c>GemmaSaveConversation</c> on the native context.</summary>
    public void SaveConversation()
    {
        ThrowIfUnusable();

        GemmaSaveConversation(_context);
        Debug.WriteLine($"Gemma: Saved current conversation ('{GetCurrentConversation()}') to prewarmed cache.");
    }

    /// <summary>Returns the value of <c>GemmaCountTokens</c> for <paramref name="prompt"/>.</summary>
    public int CountTokens(string prompt)
    {
        ThrowIfUnusable();

        return GemmaCountTokens(_context, prompt);
    }

    /// <summary>Generates text for <paramref name="prompt"/> without a token callback.</summary>
    public string Generate(string prompt, int maxOutputChars = 4096)
    {
        return Generate(prompt, null, maxOutputChars);
    }

    /// <summary>
    /// Generates text for <paramref name="prompt"/>, optionally reporting each
    /// piece of text to <paramref name="callback"/> as it is produced.
    /// </summary>
    /// <param name="prompt">Prompt text, passed to the native library as UTF-8.</param>
    /// <param name="callback">Optional per-token callback; may be null.</param>
    /// <param name="maxOutputChars">
    /// Capacity reported to the native library. The managed buffer is four times
    /// this many bytes to allow for the worst-case UTF-8 size.
    /// </param>
    /// <returns>The first <c>length</c> bytes of the output buffer decoded as UTF-8.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="maxOutputChars"/> is negative or too large to size the buffer.
    /// </exception>
    /// <exception cref="GemmaException">The native call returned a negative length.</exception>
    public string Generate(string prompt, TokenCallback callback, int maxOutputChars = 4096)
    {
        ThrowIfUnusable();

        // maxOutputChars * 4 must not wrap: a wrapped product would allocate a
        // tiny buffer while the native side is told it has far more room.
        if (maxOutputChars < 0 || maxOutputChars > int.MaxValue / 4)
        {
            throw new ArgumentOutOfRangeException(nameof(maxOutputChars));
        }

        var outputBuffer = new byte[maxOutputChars * 4]; // Allow for worst case UTF-8 size
        GemmaTokenCallback nativeCallback = null;

        // Track token count for debugging
        int tokenCount = 0;

        if (callback != null)
        {
            nativeCallback = (text, _) =>
            {
                tokenCount++;
                // Log token for debugging
                Debug.WriteLine($"Token {tokenCount}: '{text}'");

                // Pass token to user callback
                return callback(text);
            };
        }

        try
        {
            int length = GemmaGenerate(_context, prompt, outputBuffer, maxOutputChars,
                nativeCallback, IntPtr.Zero);

            if (length < 0)
            {
                throw new GemmaException("Generation failed");
            }

            Debug.WriteLine($"Generation complete: {tokenCount} tokens processed, result length: {length}");

            // Never decode past the end of the managed buffer, whatever length
            // the native side reports.
            int byteCount = Math.Min(length, outputBuffer.Length);
            return Encoding.UTF8.GetString(outputBuffer, 0, byteCount);
        }
        finally
        {
            // Keep the delegate reachable until the native call has returned.
            GC.KeepAlive(nativeCallback);
        }
    }

    /// <summary>Generates text for a prompt plus image without a token callback.</summary>
    public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxOutputChars = 4096)
    {
        return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxOutputChars);
    }

    /// <summary>
    /// Generates text for a prompt plus image, optionally reporting each piece of
    /// text to <paramref name="callback"/> as it is produced.
    /// </summary>
    /// <param name="prompt">Prompt text, passed to the native library as UTF-8.</param>
    /// <param name="imageData">
    /// Image buffer; must hold at least <c>imageWidth * imageHeight * 3</c> floats.
    /// </param>
    /// <param name="imageWidth">Image width; must be positive.</param>
    /// <param name="imageHeight">Image height; must be positive.</param>
    /// <param name="callback">Optional per-token callback; may be null.</param>
    /// <param name="maxOutputChars">Output capacity reported to the native library.</param>
    /// <exception cref="ArgumentException">The image buffer or dimensions are invalid.</exception>
    /// <exception cref="GemmaException">The native call returned a negative length.</exception>
    public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxOutputChars = 4096)
    {
        ThrowIfUnusable();

        if (imageData == null || imageData.Length == 0)
        {
            throw new ArgumentException("Image data cannot be null or empty", nameof(imageData));
        }

        if (imageWidth <= 0 || imageHeight <= 0)
        {
            throw new ArgumentException("Image dimensions must be positive");
        }

        // Computed in 64-bit: an int product can wrap for large dimensions and
        // let an undersized buffer through to native code.
        long requiredFloats = (long)imageWidth * imageHeight * 3;
        if (imageData.Length < requiredFloats)
        {
            throw new ArgumentException("Image data array is too small for the specified dimensions");
        }

        var output = new StringBuilder(maxOutputChars);
        GemmaTokenCallback nativeCallback = null;

        if (callback != null)
        {
            nativeCallback = (text, _) => callback(text);
        }

        // Pin the image data so it doesn't move during the native call
        GCHandle imageHandle = GCHandle.Alloc(imageData, GCHandleType.Pinned);

        try
        {
            IntPtr imagePtr = imageHandle.AddrOfPinnedObject();

            int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxOutputChars,
                nativeCallback, IntPtr.Zero);

            if (length < 0)
            {
                throw new GemmaException("Multimodal generation failed");
            }

            return output.ToString();
        }
        finally
        {
            imageHandle.Free();

            // Keep the delegate reachable until the native call has returned.
            GC.KeepAlive(nativeCallback);
        }
    }

    /// <summary>
    /// Destroys the native context and releases the log-callback handle.
    /// Calling it more than once is a no-op.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        // Everything has been released; the finalizer has nothing left to do.
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Releases the native context and the log-callback handle.
    /// </summary>
    /// <param name="disposing">
    /// True when called from <see cref="Dispose()"/>, false from the finalizer.
    /// Only unmanaged state is touched, so both paths do the same work.
    /// </param>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }

        if (_context != IntPtr.Zero)
        {
            GemmaDestroy(_context);
            _context = IntPtr.Zero;
        }

        // Released only after the context is gone, so native code can no longer
        // call the delegate.
        if (_logCallbackHandle.IsAllocated)
        {
            _logCallbackHandle.Free();
        }

        _loggingEnabled = false;
        _disposed = true;
    }

    ~Gemma()
    {
        Dispose(false);
    }
}
|
|
}
|