From c5058366dccc6b91895147649c1d6507f9b91d7f Mon Sep 17 00:00:00 2001 From: Han Yin Date: Thu, 26 Jun 2025 21:42:43 -0700 Subject: [PATCH] lib: hide the internal implementations, only expose a facade and interfaces --- .../java/com/example/llama/di/AppModule.kt | 4 +- .../llama/src/main/cpp/cpu_detector.cpp | 4 +- .../llama/src/main/cpp/llama-android.cpp | 20 +- .../llama/cpp/InferenceEngineLoader.kt | 179 ------ .../java/android/llama/cpp/KleidiLlama.kt | 22 + .../java/android/llama/cpp/TierDetection.kt | 27 + .../java/android/llama/cpp/gguf/FileType.kt | 59 ++ .../android/llama/cpp/gguf/GgufMetadata.kt | 57 -- .../llama/cpp/gguf/GgufMetadataReader.kt | 563 +---------------- .../cpp/internal/InferenceEngineFactory.kt | 13 + .../cpp/{ => internal}/InferenceEngineImpl.kt | 63 +- .../cpp/internal/InferenceEngineLoader.kt | 165 +++++ .../llama/cpp/internal/TierDetectionImpl.kt | 15 + .../internal/gguf/GgufMetadataReaderImpl.kt | 568 ++++++++++++++++++ 14 files changed, 917 insertions(+), 842 deletions(-) delete mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineLoader.kt create mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt create mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt create mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/FileType.kt create mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt rename examples/llama.android/llama/src/main/java/android/llama/cpp/{ => internal}/InferenceEngineImpl.kt (80%) create mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt create mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt create mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/internal/gguf/GgufMetadataReaderImpl.kt diff --git 
a/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt b/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt index 3b58b6ca5d..658096617e 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt @@ -2,7 +2,7 @@ package com.example.llama.di import android.content.Context import android.llama.cpp.InferenceEngine -import android.llama.cpp.InferenceEngineLoader +import android.llama.cpp.KleidiLlama import android.llama.cpp.gguf.GgufMetadataReader import com.example.llama.data.local.AppDatabase import com.example.llama.data.remote.HuggingFaceApiService @@ -66,7 +66,7 @@ internal abstract class AppModule { return if (USE_STUB_ENGINE) { StubInferenceEngine() } else { - InferenceEngineLoader.createInstance(context) + KleidiLlama.createInferenceEngine(context) ?: throw InstantiationException("Cannot instantiate InferenceEngine!") } } diff --git a/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp b/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp index 2913036fa7..2b60fee829 100644 --- a/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp +++ b/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp @@ -12,7 +12,7 @@ static const Aarch64Info info = GetAarch64Info(); static const Aarch64Features features = info.features; extern "C" JNIEXPORT jint JNICALL -Java_android_llama_cpp_InferenceEngineLoader_getOptimalTier( +Java_android_llama_cpp_internal_InferenceEngineLoader_getOptimalTier( JNIEnv* env, jclass clazz) { int tier = 0; // Default to T0 (baseline) @@ -46,7 +46,7 @@ Java_android_llama_cpp_InferenceEngineLoader_getOptimalTier( // Optional: Keep a feature string function for debugging extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_InferenceEngineLoader_getCpuFeaturesString( +Java_android_llama_cpp_internal_InferenceEngineLoader_getCpuFeaturesString( JNIEnv* env, jclass clazz) 
{ std::string text; diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index e06150572a..3b8f526043 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -72,7 +72,7 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) { extern "C" JNIEXPORT void JNICALL -Java_android_llama_cpp_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/) { +Java_android_llama_cpp_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/) { // Set llama log handler to Android llama_log_set(log_callback, nullptr); @@ -83,7 +83,7 @@ Java_android_llama_cpp_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/) extern "C" JNIEXPORT jint JNICALL -Java_android_llama_cpp_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) { +Java_android_llama_cpp_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) { llama_model_params model_params = llama_model_default_params(); const auto *model_path = env->GetStringUTFChars(jmodel_path, 0); @@ -137,7 +137,7 @@ static common_sampler *new_sampler(float temp) { extern "C" JNIEXPORT jint JNICALL -Java_android_llama_cpp_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) { +Java_android_llama_cpp_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) { auto *context = init_context(g_model); if (!context) { return 1; } g_context = context; @@ -161,13 +161,13 @@ static std::string get_backend() { extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) { +Java_android_llama_cpp_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) { return env->NewStringUTF(llama_print_system_info()); } extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_InferenceEngineImpl_benchModel(JNIEnv *env, jobject 
/*unused*/, jint pp, jint tg, +Java_android_llama_cpp_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) { auto *context = init_context(g_model, pp); if (!context) { @@ -377,7 +377,7 @@ static int decode_tokens_in_batches( extern "C" JNIEXPORT jint JNICALL -Java_android_llama_cpp_InferenceEngineImpl_processSystemPrompt( +Java_android_llama_cpp_internal_InferenceEngineImpl_processSystemPrompt( JNIEnv *env, jobject /*unused*/, jstring jsystem_prompt @@ -426,7 +426,7 @@ Java_android_llama_cpp_InferenceEngineImpl_processSystemPrompt( extern "C" JNIEXPORT jint JNICALL -Java_android_llama_cpp_InferenceEngineImpl_processUserPrompt( +Java_android_llama_cpp_internal_InferenceEngineImpl_processUserPrompt( JNIEnv *env, jobject /*unused*/, jstring juser_prompt, @@ -510,7 +510,7 @@ static bool is_valid_utf8(const char *string) { extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_InferenceEngineImpl_generateNextToken( +Java_android_llama_cpp_internal_InferenceEngineImpl_generateNextToken( JNIEnv *env, jobject /*unused*/ ) { @@ -570,7 +570,7 @@ Java_android_llama_cpp_InferenceEngineImpl_generateNextToken( extern "C" JNIEXPORT void JNICALL -Java_android_llama_cpp_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) { +Java_android_llama_cpp_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) { // Reset long-term & short-term states reset_long_term_states(); reset_short_term_states(); @@ -585,6 +585,6 @@ Java_android_llama_cpp_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject / extern "C" JNIEXPORT void JNICALL -Java_android_llama_cpp_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) { +Java_android_llama_cpp_internal_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) { llama_backend_free(); } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineLoader.kt 
b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineLoader.kt deleted file mode 100644 index 89412ac068..0000000000 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineLoader.kt +++ /dev/null @@ -1,179 +0,0 @@ -package android.llama.cpp - -import android.content.Context -import android.util.Log -import androidx.datastore.core.DataStore -import androidx.datastore.preferences.core.Preferences -import androidx.datastore.preferences.core.edit -import androidx.datastore.preferences.core.intPreferencesKey -import androidx.datastore.preferences.preferencesDataStore -import kotlinx.coroutines.flow.first -import kotlinx.coroutines.runBlocking - -enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) { - T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"), - T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"), - T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"), - T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"); - // TODO-han.yin: implement T4 once obtaining an Android device with SME! - - companion object { - fun fromRawValue(value: Int): LLamaTier? 
{ - return entries.find { it.rawValue == value } - } - - fun getMaxSupportedTier(): LLamaTier = T3 - } -} - -class InferenceEngineLoader private constructor() { - - companion object { - private val TAG = InferenceEngineLoader::class.simpleName - - // CPU feature detection preferences - private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection" - private val Context.llamaTierDataStore: DataStore - by preferencesDataStore(name = DATASTORE_CPU_DETECTION) - - private val DETECTION_VERSION = intPreferencesKey("detection_version") - private val DETECTED_TIER = intPreferencesKey("detected_tier") - - // Constants - private const val DATASTORE_VERSION = 1 - - @JvmStatic - private external fun getOptimalTier(): Int - - @JvmStatic - private external fun getCpuFeaturesString(): String - - private var _cachedInstance: InferenceEngineImpl? = null - private var _detectedTier: LLamaTier? = null - val detectedTier: LLamaTier? get() = _detectedTier - - /** - * Factory method to get a configured [InferenceEngineImpl] instance. - * Handles tier detection, caching, and library loading automatically. - */ - @Synchronized - fun createInstance(context: Context): InferenceEngine? 
{ - // Return cached instance if available - _cachedInstance?.let { return it } - - return runBlocking { - try { - // Obtain the optimal tier from cache if available - val tier = getOrDetectOptimalTier(context) ?: run { - Log.e(TAG, "Failed to determine optimal tier") - return@runBlocking null - } - _detectedTier = tier - Log.i(TAG, "Using tier: ${tier.name} (${tier.description})") - - // Create and cache the inference engine instance - val instance = InferenceEngineImpl.createWithTier(tier) ?: run { - Log.e(TAG, "Failed to instantiate InferenceEngineImpl") - return@runBlocking null - } - _cachedInstance = instance - Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}") - - instance - - } catch (e: Exception) { - Log.e(TAG, "Error creating InferenceEngineImpl instance", e) - null - } - } - } - - /** - * Clear cached detection results (for testing/debugging) - */ - fun clearCache(context: Context) { - runBlocking { context.llamaTierDataStore.edit { it.clear() } } - _cachedInstance = null - _detectedTier = null - Log.i(TAG, "Cleared detection results and cached instance") - } - - /** - * Get optimal tier from cache or detect it fresh - */ - private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? 
{ - val preferences = context.llamaTierDataStore.data.first() - - // Check if we have a cached result with the current detection version - val cachedVersion = preferences[DETECTION_VERSION] ?: -1 - val cachedTierValue = preferences[DETECTED_TIER] ?: -1 - if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) { - val cachedTier = LLamaTier.fromRawValue(cachedTierValue) - if (cachedTier != null) { - Log.i(TAG, "Using cached tier detection: ${cachedTier.name}") - return cachedTier - } - } - - // No valid cache, detect fresh - Log.i(TAG, "Performing fresh tier detection") - return detectAndCacheOptimalTier(context) - } - - /** - * Detect optimal tier and save to cache - */ - private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? { - try { - // Load CPU detection library - System.loadLibrary("llama_cpu_detector") - Log.i(TAG, "CPU feature detector loaded successfully") - - // Detect optimal tier - val tierValue = getOptimalTier() - val features = getCpuFeaturesString() - Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features") - - // Convert to enum and validate - val tier = LLamaTier.fromRawValue(tierValue) ?: run { - Log.w(TAG, "Invalid tier value $tierValue") - return null - } - - // Ensure we don't exceed maximum supported tier - val finalTier = if (tier.rawValue > LLamaTier.getMaxSupportedTier().rawValue) { - Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.getMaxSupportedTier().name}") - LLamaTier.getMaxSupportedTier() - } else { - tier - } - - // Cache the result - context.llamaTierDataStore.edit { - it[DETECTED_TIER] = finalTier.rawValue - it[DETECTION_VERSION] = DATASTORE_VERSION - } - - Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}") - return finalTier - - } catch (e: UnsatisfiedLinkError) { - Log.e(TAG, "Failed to load CPU detection library", e) - - // Fallback to T0 and cache it - val fallbackTier = LLamaTier.T0 - context.llamaTierDataStore.edit { - it[DETECTED_TIER] = 
fallbackTier.rawValue - it[DETECTION_VERSION] = DATASTORE_VERSION - } - - Log.i(TAG, "Using fallback tier: ${fallbackTier.name}") - return fallbackTier - - } catch (e: Exception) { - Log.e(TAG, "Unexpected error during tier detection", e) - return null - } - } - } -} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt new file mode 100644 index 0000000000..8a0b264ee4 --- /dev/null +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt @@ -0,0 +1,22 @@ +package android.llama.cpp + +import android.content.Context +import android.llama.cpp.internal.InferenceEngineFactory + +/** + * Main entry point for the Llama Android library. + * This is the only class that should be used by library consumers. + */ +object KleidiLlama { + /** + * Create an inference engine instance with automatic tier detection. + */ + fun createInferenceEngine(context: Context) = + InferenceEngineFactory.createInstance(context) + + /** + * Get tier detection information for debugging/settings. + */ + fun getTierDetection(context: Context) = + InferenceEngineFactory.getTierDetection(context) +} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt new file mode 100644 index 0000000000..1d5c6566ea --- /dev/null +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt @@ -0,0 +1,27 @@ +package android.llama.cpp + +/** + * Public interface for [LLamaTier] detection information. + */ +interface TierDetection { + val detectedTier: LLamaTier? + fun clearCache() +} + +/** + * ARM optimization tiers supported by the Kleidi-Llama library. + * Higher tiers provide better performance on supported hardware. 
+ */ +enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) { + T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"), + T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"), + T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"), + T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"); + // TODO-han.yin: implement T4 once obtaining an Android device with SME! + + companion object { + fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value } + + val maxSupportedTier = T3 + } +} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/FileType.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/FileType.kt new file mode 100644 index 0000000000..ea694814ac --- /dev/null +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/FileType.kt @@ -0,0 +1,59 @@ +package android.llama.cpp.gguf + + +/** + * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`). + * The `label` matches what llama‑cli prints. 
+ */ +enum class FileType(val code: Int, val label: String) { + ALL_F32(0, "all F32"), + MOSTLY_F16(1, "F16"), + MOSTLY_Q4_0(2, "Q4_0"), + MOSTLY_Q4_1(3, "Q4_1"), + // 4 removed + MOSTLY_Q8_0(7, "Q8_0"), + MOSTLY_Q5_0(8, "Q5_0"), + MOSTLY_Q5_1(9, "Q5_1"), + + /* K‑quants ------------------------------------------------------------ */ + MOSTLY_Q2_K (10, "Q2_K - Medium"), + MOSTLY_Q3_K_S (11, "Q3_K - Small"), + MOSTLY_Q3_K_M (12, "Q3_K - Medium"), + MOSTLY_Q3_K_L (13, "Q3_K - Large"), + MOSTLY_Q4_K_S (14, "Q4_K - Small"), + MOSTLY_Q4_K_M (15, "Q4_K - Medium"), + MOSTLY_Q5_K_S (16, "Q5_K - Small"), + MOSTLY_Q5_K_M (17, "Q5_K - Medium"), + MOSTLY_Q6_K (18, "Q6_K"), + + /* IQ quants ----------------------------------------------------------- */ + MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"), + MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"), + MOSTLY_Q2_K_S (21, "Q2_K - Small"), + MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"), + MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"), + MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"), + MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"), + MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"), + MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"), + MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"), + MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"), + MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"), + MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"), + + /* BF16 & Ternary ------------------------------------------------------ */ + MOSTLY_BF16 (32, "BF16"), + MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"), + MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"), + + /* Special flag -------------------------------------------------------- */ + GUESSED(1024, "(guessed)"), + + UNKNOWN(-1, "unknown"); + + companion object { + private val map = entries.associateBy(FileType::code) + + fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN + } +} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadata.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadata.kt index 8a51e35e79..1e48773037 
100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadata.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadata.kt @@ -130,60 +130,3 @@ data class GgufMetadata( val usedCount: Int? = null, ) } - -/** - * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`). - * The `label` matches what llama‑cli prints. - */ -enum class FileType(val code: Int, val label: String) { - ALL_F32(0, "all F32"), - MOSTLY_F16(1, "F16"), - MOSTLY_Q4_0(2, "Q4_0"), - MOSTLY_Q4_1(3, "Q4_1"), - // 4 removed - MOSTLY_Q8_0(7, "Q8_0"), - MOSTLY_Q5_0(8, "Q5_0"), - MOSTLY_Q5_1(9, "Q5_1"), - - /* K‑quants ------------------------------------------------------------ */ - MOSTLY_Q2_K (10, "Q2_K - Medium"), - MOSTLY_Q3_K_S (11, "Q3_K - Small"), - MOSTLY_Q3_K_M (12, "Q3_K - Medium"), - MOSTLY_Q3_K_L (13, "Q3_K - Large"), - MOSTLY_Q4_K_S (14, "Q4_K - Small"), - MOSTLY_Q4_K_M (15, "Q4_K - Medium"), - MOSTLY_Q5_K_S (16, "Q5_K - Small"), - MOSTLY_Q5_K_M (17, "Q5_K - Medium"), - MOSTLY_Q6_K (18, "Q6_K"), - - /* IQ quants ----------------------------------------------------------- */ - MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"), - MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"), - MOSTLY_Q2_K_S (21, "Q2_K - Small"), - MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"), - MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"), - MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"), - MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"), - MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"), - MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"), - MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"), - MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"), - MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"), - MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"), - - /* BF16 & Ternary ------------------------------------------------------ */ - MOSTLY_BF16 (32, "BF16"), - MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"), - MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"), - - /* Special flag -------------------------------------------------------- */ - GUESSED(1024, 
"(guessed)"), - - UNKNOWN(-1, "unknown"); - - companion object { - private val map = entries.associateBy(FileType::code) - - fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN - } -} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadataReader.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadataReader.kt index 777cb8ec47..bfc590fff5 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadataReader.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadataReader.kt @@ -1,8 +1,7 @@ package android.llama.cpp.gguf -import java.io.File +import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl import java.io.IOException -import java.io.InputStream /** * Interface for reading GGUF metadata from model files. @@ -51,563 +50,3 @@ interface GgufMetadataReader { ) } } - -/** - * Utility class to read GGUF model files and extract metadata key-value pairs. - * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data. - */ -private class GgufMetadataReaderImpl( - private val skipKeys: Set, - private val arraySummariseThreshold: Int, -) : GgufMetadataReader { - companion object { - private const val ARCH_LLAMA = "llama" - } - - /** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */ - enum class MetadataType(val code: Int) { - UINT8(0), INT8(1), UINT16(2), INT16(3), - UINT32(4), INT32(5), FLOAT32(6), BOOL(7), - STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12); - companion object { - private val codeMap = values().associateBy(MetadataType::code) - fun fromCode(code: Int): MetadataType = codeMap[code] - ?: throw IOException("Unknown metadata value type code: $code") - } - } - - /** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. 
*/ - sealed class MetadataValue { - data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int - data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int - data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian) - data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian) - data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian) - data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian) - data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float - data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true) - data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed) - data class ArrayVal(val elementType: MetadataType, val elements: List) : MetadataValue() - data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian) - data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian) - data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double - } - - /* Convert MetadataValue to plain Kotlin primitives for allMetadata map */ - private fun MetadataValue.toPrimitive(): Any = when (this) { - is MetadataValue.UInt8 -> value - is MetadataValue.Int8 -> value - is MetadataValue.UInt16 -> value - is MetadataValue.Int16 -> value - is MetadataValue.UInt32 -> value - is MetadataValue.Int32 -> value - is MetadataValue.Float32 -> value - is MetadataValue.Bool -> value - is MetadataValue.StringVal -> value - is MetadataValue.UInt64 -> value - is MetadataValue.Int64 -> value - is MetadataValue.Float64 -> value - is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() } - } - - /** - * High‑level entry point: parses a `.gguf` file on disk and returns the fully - * populated [GgufMetadata] tree. 
- * - * Steps performed internally: - * 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version). - * 2. Streams through the key‑value section, skipping large blobs if the key - * appears in [skipKeys] or if an array exceeds [arraySummariseThreshold]. - * 3. Converts the resulting raw map into strongly‑typed sub‑structures - * (basic info, tokenizer, rope, etc.). - * - * The method is STREAMING‑ONLY: tensors are never mapped or loaded into - * memory, so even multi‑GB model files can be processed in < 50 ms. - * - * @param path Absolute or relative filesystem path to a `.gguf` file. - * @return A [GgufMetadata] instance containing all recognised metadata plus - * an `allMetadata` map with any keys that were not given a dedicated - * field. - * @throws IOException if the file is not GGUF, the version is unsupported, - * or the metadata block is truncated / corrupt. - */ - override suspend fun readStructuredMetadata(path: String): GgufMetadata { - File(path).inputStream().buffered().use { input -> - // ── 1. header ────────────────────────────────────────────────────────── - // throws on mismatch - val version = ensureMagicAndVersion(input) - val tensorCount = readLittleLong(input) - val kvCount = readLittleLong(input) - - // ── 2. metadata map (reuse our raw parser, but we need access to the stream) ── - val meta = readMetaMap(input, kvCount) // - - // ── 3. build structured object ──────────────────────────────────────── - return buildStructured(meta, version, tensorCount, kvCount) - } - } - - /** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". 
*/ - private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion { - val magic = ByteArray(4) - if (input.read(magic) != 4) throw IOException("File too short (no magic)") - if (!magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46))) // "GGUF" - throw IOException("Not a GGUF file (bad magic)") - return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input)) - } - - /** - * Read an unsigned 32‑bit little‑endian integer. - * - * @throws IOException if fewer than four bytes are available. - */ - private fun readLEUInt32(input: InputStream): Int { - val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read() - if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32") - return (b3 and 0xFF shl 24) or - (b2 and 0xFF shl 16) or - (b1 and 0xFF shl 8) or - (b0 and 0xFF) - } - - /** - * Low‑level helper that reads the entire “key-value” section from the current - * stream position. - * - * @param input Open stream positioned JUST AFTER the header. - * @param kvCnt Number of key‑value pairs (taken from the header). - * @return Mutable map with one [MetadataValue] for every key that is NOT skipped. - * - * The function honours [skipKeys] and [arraySummariseThreshold] by invoking - * [skipValue] or [parseValue] accordingly. - */ - private fun readMetaMap(input: InputStream, kvCnt: Long): Map { - val map = mutableMapOf() - repeat(kvCnt.toInt()) { - val key = readString(input) - val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) - if (key in skipKeys) { - skipValue(input, valueT) - } else { - map[key] = parseValue(input, valueT) - } - } - return map - } - - /** - * Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed - * [GgufMetadata] tree used by the rest of the app. - * - * Only the keys listed in the spec are copied into dedicated data classes; - * everything else is preserved in `GgufMetadata.allMetadata`. - * - * @param m Raw key/value map. 
- * @param version GGUF file‑format version (enum). - * @param tensorCnt Number of tensors (from the header). - * @param kvCnt Total metadata pair count (from the header). - */ - private fun buildStructured( - m: Map, - version: GgufMetadata.GgufVersion, - tensorCnt: Long, - kvCnt: Long - ): GgufMetadata { - // ---------- helpers ---------- - fun String.str() = (m[this] as? MetadataValue.StringVal)?.value - fun String.bool() = (m[this] as? MetadataValue.Bool)?.value - fun String.i32() = (m[this] as? MetadataValue.Int32)?.value - fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt() - fun String.f32() = (m[this] as? MetadataValue.Float32)?.value - fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat() - fun String.strList(): List? = - (m[this] as? MetadataValue.ArrayVal) - ?.elements - ?.mapNotNull { (it as? MetadataValue.StringVal)?.value } - - val arch = "general.architecture".str() ?: ARCH_LLAMA - - // -------------- populate sections ---------------- - val basic = GgufMetadata.BasicInfo( - uuid = "general.uuid".str(), - name = "general.basename".str(), - nameLabel = "general.name".str(), - sizeLabel = "general.size_label".str() - ) - - val author = GgufMetadata.AuthorInfo( - organization = "general.organization".str(), - author = "general.author".str(), - doi = "general.doi".str(), - url = "general.url".str(), - repoUrl = "general.repo_url".str(), - license = "general.license".str(), - licenseLink = "general.license.link".str() - ).takeUnless { - organization == null && author == null && doi == null && - url == null && repoUrl == null && license == null && licenseLink == null - } - - val additional = GgufMetadata.AdditionalInfo( - type = "general.type".str(), - description = "general.description".str(), - tags = "general.tags".strList(), - languages = "general.languages".strList() - ).takeUnless { - type == null && description == null && tags == null && languages == null - } - - val architectureInfo = 
GgufMetadata.ArchitectureInfo( - architecture = arch, - fileType = "general.file_type".u32(), - vocabSize = "$arch.vocab_size".u32(), - finetune = "general.finetune".str(), - quantizationVersion = "general.quantization_version".u32() - ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null } - - val baseModels = buildList { - val n = "general.base_model.count".u32() ?: 0 - for (i in 0 until n) { - fun k(s: String) = "general.base_model.$i.$s" - add( - GgufMetadata.BaseModelInfo( - name = k("name").str(), - author = k("author").str(), - version = k("version").str(), - organization = k("organization").str(), - url = k("url").str(), - doi = k("doi").str(), - uuid = k("uuid").str(), - repoUrl = k("repo_url").str(), - ) - ) - } - }.takeIf { it.isNotEmpty() } - - val tokenizer = GgufMetadata.TokenizerInfo( - model = "tokenizer.ggml.model".str(), - bosTokenId = "tokenizer.ggml.bos_token_id".u32(), - eosTokenId = "tokenizer.ggml.eos_token_id".u32(), - unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(), - paddingTokenId = "tokenizer.ggml.padding_token_id".u32(), - addBosToken = "tokenizer.ggml.add_bos_token".bool(), - addEosToken = "tokenizer.ggml.add_eos_token".bool(), - chatTemplate = "tokenizer.chat_template".str() - ).takeUnless { model == null && bosTokenId == null && eosTokenId == null && - unknownTokenId == null && paddingTokenId == null && - addBosToken == null && addEosToken == null && chatTemplate == null - } - - val dimensions = GgufMetadata.DimensionsInfo( - contextLength = "$arch.context_length".u32(), - embeddingSize = "$arch.embedding_length".u32(), - blockCount = "$arch.block_count".u32(), - feedForwardSize = "$arch.feed_forward_length".u32() - ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null } - - val attention = GgufMetadata.AttentionInfo( - headCount = "$arch.attention.head_count".u32(), - headCountKv = 
"$arch.attention.head_count_kv".u32(), - keyLength = "$arch.attention.key_length".u32(), - valueLength = "$arch.attention.value_length".u32(), - layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(), - layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(), - ).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null && - layerNormEpsilon == null && layerNormRmsEpsilon == null - } - - val rope = GgufMetadata.RopeInfo( - frequencyBase = "$arch.rope.freq_base".f32(), - dimensionCount = "$arch.rope.dimension_count".u32(), - scalingType = "$arch.rope.scaling.type".str(), - scalingFactor = "$arch.rope.scaling.factor".f32(), - attnFactor = "$arch.rope.scaling.attn_factor".f32(), - originalContextLength = "$arch.rope.scaling.original_context_length".u32(), - finetuned = "$arch.rope.scaling.finetuned".bool() - ).takeUnless { frequencyBase == null && dimensionCount == null && - scalingType == null && scalingFactor == null && attnFactor == null && - originalContextLength == null && finetuned == null - } - - val experts = GgufMetadata.ExpertsInfo( - count = "$arch.expert_count".u32(), - usedCount = "$arch.expert_used_count".u32() - ).takeUnless { count == null && usedCount == null } - - return GgufMetadata( - version = version, - tensorCount = tensorCnt, - kvCount = kvCnt, - basic = basic, - author = author, - additional = additional, - architecture = architectureInfo, - baseModels = baseModels, - tokenizer = tokenizer, - dimensions = dimensions, - attention = attention, - rope = rope, - experts = experts - ) - } - - /** - * Recursively parses a metadata value of the given type from the input stream. - * @param input The input stream positioned at the start of the value. - * @param type The metadata value type to parse. 
- */ - private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) { - MetadataType.UINT8 -> { - // 1-byte unsigned integer - val byteVal = input.read() - if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.") - MetadataValue.UInt8(byteVal.toUByte()) - } - MetadataType.INT8 -> { - // 1-byte signed integer - val byteVal = input.read() - if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.") - MetadataValue.Int8(byteVal.toByte()) - } - MetadataType.UINT16 -> { - // 2-byte unsigned integer (little-endian) - val bytes = ByteArray(2) - if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.") - // Combine two bytes (little-endian) into an unsigned 16-bit value - val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF) - MetadataValue.UInt16(u16.toUShort()) - } - MetadataType.INT16 -> { - // 2-byte signed integer (little-endian) - val bytes = ByteArray(2) - if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.") - // Combine to 16-bit and interpret as signed - val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF) - MetadataValue.Int16(i16.toShort()) - } - MetadataType.UINT32 -> { - // 4-byte unsigned integer (little-endian) - val bytes = ByteArray(4) - if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.") - // Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt - val u32 = (bytes[3].toLong() and 0xFFL shl 24) or - (bytes[2].toLong() and 0xFFL shl 16) or - (bytes[1].toLong() and 0xFFL shl 8) or - (bytes[0].toLong() and 0xFFL) - MetadataValue.UInt32(u32.toUInt()) - } - MetadataType.INT32 -> { - // 4-byte signed integer (little-endian) - val bytes = ByteArray(4) - if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.") - // Combine four bytes into a 32-bit signed int - val 
i32 = (bytes[3].toInt() and 0xFF shl 24) or - (bytes[2].toInt() and 0xFF shl 16) or - (bytes[1].toInt() and 0xFF shl 8) or - (bytes[0].toInt() and 0xFF) - MetadataValue.Int32(i32) - } - MetadataType.FLOAT32 -> { - // 4-byte IEEE 754 float (little-endian) - val bytes = ByteArray(4) - if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.") - // Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float - val bits = (bytes[3].toInt() and 0xFF shl 24) or - (bytes[2].toInt() and 0xFF shl 16) or - (bytes[1].toInt() and 0xFF shl 8) or - (bytes[0].toInt() and 0xFF) - val floatVal = Float.fromBits(bits) - MetadataValue.Float32(floatVal) - } - MetadataType.BOOL -> { - // 1-byte boolean (0 = false, 1 = true) - val byteVal = input.read() - if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.") - if (byteVal != 0 && byteVal != 1) { - throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).") - } - MetadataValue.Bool(byteVal != 0) - } - MetadataType.STRING -> { - // UTF-8 string (length-prefixed with 8-byte length) - val str = readString(input) - MetadataValue.StringVal(str) - } - MetadataType.ARRAY -> { - val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) - val len = readLittleLong(input) - val count = len.toInt() - - if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) { - // fast‑forward without allocation - repeat(count) { skipValue(input, elemType) } - MetadataValue.StringVal("Array($elemType, $count items) /* summarised */") - } else { - val list = ArrayList(count) - repeat(count) { list += parseValue(input, elemType) } - MetadataValue.ArrayVal(elemType, list) - } - } - MetadataType.UINT64 -> { - // 8-byte unsigned integer (little-endian) - val bytes = ByteArray(8) - if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.") - // Combine 8 bytes into an unsigned 64-bit (ULong). 
Use ULong for full 0 to 2^64-1 range. - val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or - (bytes[6].toULong() and 0xFFuL shl 48) or - (bytes[5].toULong() and 0xFFuL shl 40) or - (bytes[4].toULong() and 0xFFuL shl 32) or - (bytes[3].toULong() and 0xFFuL shl 24) or - (bytes[2].toULong() and 0xFFuL shl 16) or - (bytes[1].toULong() and 0xFFuL shl 8) or - (bytes[0].toULong() and 0xFFuL) - MetadataValue.UInt64(u64) - } - MetadataType.INT64 -> { - // 8-byte signed integer (little-endian) - val bytes = ByteArray(8) - if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.") - // Combine 8 bytes into a signed 64-bit value (Long) - val i64 = (bytes[7].toLong() and 0xFFL shl 56) or - (bytes[6].toLong() and 0xFFL shl 48) or - (bytes[5].toLong() and 0xFFL shl 40) or - (bytes[4].toLong() and 0xFFL shl 32) or - (bytes[3].toLong() and 0xFFL shl 24) or - (bytes[2].toLong() and 0xFFL shl 16) or - (bytes[1].toLong() and 0xFFL shl 8) or - (bytes[0].toLong() and 0xFFL) - MetadataValue.Int64(i64) - } - MetadataType.FLOAT64 -> { - // 8-byte IEEE 754 double (little-endian) - val bytes = ByteArray(8) - if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.") - // Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double - val bits = (bytes[7].toLong() and 0xFFL shl 56) or - (bytes[6].toLong() and 0xFFL shl 48) or - (bytes[5].toLong() and 0xFFL shl 40) or - (bytes[4].toLong() and 0xFFL shl 32) or - (bytes[3].toLong() and 0xFFL shl 24) or - (bytes[2].toLong() and 0xFFL shl 16) or - (bytes[1].toLong() and 0xFFL shl 8) or - (bytes[0].toLong() and 0xFFL) - val doubleVal = Double.fromBits(bits) - MetadataValue.Float64(doubleVal) - } - } - - - private fun T?.takeUnless(check: T.() -> Boolean): T? = - this?.takeIf { !it.check() } - - /** Helper: Skip a value in the stream without storing it (still maintains pointer). 
*/ - private fun skipValue(input: InputStream, type: MetadataType) { - when (type) { - MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1) - MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2) - MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4) - MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8) - MetadataType.STRING -> { - val len = readLittleLong(input); input.skipFully(len) - } - MetadataType.ARRAY -> { - val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) - val len = readLittleLong(input) - repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip - } - } - } - - /** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */ - private fun readLittleLong(input: InputStream): Long { - val bytes = ByteArray(8) - input.readFully(bytes) - - // Combine 8 bytes into a 64-bit value (Little Endian). - // Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement). - // In our context (lengths/counts), such extremely large values are not expected. - return (bytes[7].toLong() and 0xFFL shl 56) or - (bytes[6].toLong() and 0xFFL shl 48) or - (bytes[5].toLong() and 0xFFL shl 40) or - (bytes[4].toLong() and 0xFFL shl 32) or - (bytes[3].toLong() and 0xFFL shl 24) or - (bytes[2].toLong() and 0xFFL shl 16) or - (bytes[1].toLong() and 0xFFL shl 8) or - (bytes[0].toLong() and 0xFFL) - } - - /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */ - private fun readString(input: InputStream): String { - // Read 8-byte little-endian length (number of bytes in the string). - val len = readLittleLong(input) - if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len") - - // Read the UTF-8 bytes of the given length. 
- val buf = ByteArray(len.toInt()) - if (buf.isNotEmpty()) input.readFully(buf) - return String(buf, Charsets.UTF_8) - } - - /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */ - private fun littleEndianBytesToInt(bytes: ByteArray): Int { - // Note: assumes bytes length is 4. - return (bytes[3].toInt() and 0xFF shl 24) or - (bytes[2].toInt() and 0xFF shl 16) or - (bytes[1].toInt() and 0xFF shl 8) or - (bytes[0].toInt() and 0xFF) - } - - /** - * Robust skip that works the same on JDK 11 and Android’s desugared runtime. - * - * @param n Number of bytes to advance in the stream. - * @throws IOException on premature EOF. - */ - private fun InputStream.skipFully(n: Long) { - var remaining = n - val scratch = ByteArray(8192) // read‑and‑toss buffer - while (remaining > 0) { - val skipped = skip(remaining) - when { - skipped > 0 -> remaining -= skipped // normal fast path - skipped == 0L -> { - // fallback: read and discard - val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt()) - if (read == -1) throw IOException("EOF while skipping $n bytes") - remaining -= read - } - else -> throw IOException("Skip returned negative value") - } - } - } - - /** - * Extension that keeps reading until the requested number of bytes are filled. - * Falls back to `read()` when `skip()` returns 0, which happens on some Android - * streams. - * - * @param buf Destination buffer. - * @param len Number of bytes to fill (defaults to `buf.size`). - * @throws IOException on premature EOF. - */ - private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) { - var off = 0 - while (off < len) { - val n = read(buf, off, len - off) - if (n == -1) throw IOException("EOF after $off of $len bytes") - off += n - } - } - - /** - * Read EXACTLY `n` bytes or throw – never returns a partially‑filled array. - * This is used for small fixed‑length reads (e.g. 4‑byte type codes). - * - * @throws IOException on premature EOF. 
- */ - private fun InputStream.readNBytesExact(n: Int): ByteArray { - val buf = ByteArray(n) - if (read(buf) != n) throw IOException("Unexpected EOF") - return buf - } -} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt new file mode 100644 index 0000000000..64d9461a98 --- /dev/null +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt @@ -0,0 +1,13 @@ +package android.llama.cpp.internal + +import android.content.Context +import android.llama.cpp.TierDetection + +/** + * Internal factory to create [InferenceEngine] and [TierDetection] + */ +internal object InferenceEngineFactory { + fun createInstance(context: Context) = InferenceEngineLoader.createInstance(context) + + fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context) +} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt similarity index 80% rename from examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineImpl.kt rename to examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt index 65954f9388..38325365c6 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineImpl.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt @@ -1,6 +1,7 @@ -package android.llama.cpp +package android.llama.cpp.internal -import android.llama.cpp.InferenceEngine.State +import android.llama.cpp.InferenceEngine +import android.llama.cpp.LLamaTier import android.util.Log import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope @@ -85,8 +86,9 @@ internal class InferenceEngineImpl private constructor( 
private external fun unload() private external fun shutdown() - private val _state = MutableStateFlow(State.Uninitialized) - override val state: StateFlow = _state + private val _state = + MutableStateFlow(InferenceEngine.State.Uninitialized) + override val state: StateFlow = _state private var _readyForSystemPrompt = false @@ -100,15 +102,15 @@ internal class InferenceEngineImpl private constructor( init { llamaScope.launch { try { - check(_state.value is State.Uninitialized) { + check(_state.value is InferenceEngine.State.Uninitialized) { "Cannot load native library in ${_state.value.javaClass.simpleName}!" } - _state.value = State.Initializing + _state.value = InferenceEngine.State.Initializing Log.i(TAG, "Loading native library for $tier") System.loadLibrary(tier.libraryName) init() - _state.value = State.Initialized + _state.value = InferenceEngine.State.Initialized Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}") } catch (e: Exception) { @@ -123,7 +125,7 @@ internal class InferenceEngineImpl private constructor( */ override suspend fun loadModel(pathToModel: String) = withContext(llamaDispatcher) { - check(_state.value is State.Initialized) { + check(_state.value is InferenceEngine.State.Initialized) { "Cannot load model in ${_state.value.javaClass.simpleName}!" } File(pathToModel).let { @@ -133,7 +135,7 @@ internal class InferenceEngineImpl private constructor( Log.i(TAG, "Loading model... 
\n$pathToModel") _readyForSystemPrompt = false - _state.value = State.LoadingModel + _state.value = InferenceEngine.State.LoadingModel load(pathToModel).let { result -> if (result != 0) throw IllegalStateException("Failed to Load model: $result") } @@ -142,7 +144,7 @@ internal class InferenceEngineImpl private constructor( } Log.i(TAG, "Model loaded!") _readyForSystemPrompt = true - _state.value = State.ModelReady + _state.value = InferenceEngine.State.ModelReady } /** @@ -154,40 +156,40 @@ internal class InferenceEngineImpl private constructor( withContext(llamaDispatcher) { require(prompt.isNotBlank()) { "Cannot process empty system prompt!" } check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" } - check(_state.value is State.ModelReady) { + check(_state.value is InferenceEngine.State.ModelReady) { "Cannot process system prompt in ${_state.value.javaClass.simpleName}!" } Log.i(TAG, "Sending system prompt...") _readyForSystemPrompt = false - _state.value = State.ProcessingSystemPrompt + _state.value = InferenceEngine.State.ProcessingSystemPrompt processSystemPrompt(prompt).let { result -> if (result != 0) { val errorMessage = "Failed to process system prompt: $result" - _state.value = State.Error(errorMessage) + _state.value = InferenceEngine.State.Error(errorMessage) throw IllegalStateException(errorMessage) } } Log.i(TAG, "System prompt processed! Awaiting user prompt...") - _state.value = State.ModelReady + _state.value = InferenceEngine.State.ModelReady } /** - * Send plain text user prompt to LLM, which starts generating tokens in a [Flow] + * Send plain text user prompt to LLM, which starts generating tokens in a [kotlinx.coroutines.flow.Flow] */ override fun sendUserPrompt( message: String, predictLength: Int, ): Flow = flow { require(message.isNotEmpty()) { "User prompt discarded due to being empty!" 
} - check(_state.value is State.ModelReady) { + check(_state.value is InferenceEngine.State.ModelReady) { "User prompt discarded due to: ${_state.value.javaClass.simpleName}" } try { Log.i(TAG, "Sending user prompt...") _readyForSystemPrompt = false - _state.value = State.ProcessingUserPrompt + _state.value = InferenceEngine.State.ProcessingUserPrompt processUserPrompt(message, predictLength).let { result -> if (result != 0) { @@ -197,21 +199,21 @@ internal class InferenceEngineImpl private constructor( } Log.i(TAG, "User prompt processed. Generating assistant prompt...") - _state.value = State.Generating + _state.value = InferenceEngine.State.Generating while (true) { generateNextToken()?.let { utf8token -> if (utf8token.isNotEmpty()) emit(utf8token) } ?: break } Log.i(TAG, "Assistant generation complete. Awaiting user prompt...") - _state.value = State.ModelReady + _state.value = InferenceEngine.State.ModelReady } catch (e: CancellationException) { Log.i(TAG, "Generation cancelled by user.") - _state.value = State.ModelReady + _state.value = InferenceEngine.State.ModelReady throw e } catch (e: Exception) { Log.e(TAG, "Error during generation!", e) - _state.value = State.Error(e.message ?: "Unknown error") + _state.value = InferenceEngine.State.Error(e.message ?: "Unknown error") throw e } }.flowOn(llamaDispatcher) @@ -221,14 +223,14 @@ internal class InferenceEngineImpl private constructor( */ override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String = withContext(llamaDispatcher) { - check(_state.value is State.ModelReady) { + check(_state.value is InferenceEngine.State.ModelReady) { "Benchmark request discarded due to: $state" } Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)") _readyForSystemPrompt = false // Just to be safe - _state.value = State.Benchmarking + _state.value = InferenceEngine.State.Benchmarking benchModel(pp, tg, pl, nr).also { - _state.value = State.ModelReady + _state.value = InferenceEngine.State.ModelReady } } 
@@ -237,18 +239,19 @@ internal class InferenceEngineImpl private constructor( */ override suspend fun unloadModel() = withContext(llamaDispatcher) { - when(val state = _state.value) { - is State.ModelReady, is State.Error -> { + when (val state = _state.value) { + is InferenceEngine.State.ModelReady, is InferenceEngine.State.Error -> { Log.i(TAG, "Unloading model and free resources...") _readyForSystemPrompt = false - _state.value = State.UnloadingModel + _state.value = InferenceEngine.State.UnloadingModel unload() - _state.value = State.Initialized + _state.value = InferenceEngine.State.Initialized Log.i(TAG, "Model unloaded!") Unit } + else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}") } } @@ -260,8 +263,8 @@ internal class InferenceEngineImpl private constructor( _readyForSystemPrompt = false llamaScope.cancel() when(_state.value) { - is State.Uninitialized -> {} - is State.Initialized -> shutdown() + is InferenceEngine.State.Uninitialized -> {} + is InferenceEngine.State.Initialized -> shutdown() else -> { unload(); shutdown() } } } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt new file mode 100644 index 0000000000..c23285147f --- /dev/null +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt @@ -0,0 +1,165 @@ +package android.llama.cpp.internal + +import android.content.Context +import android.llama.cpp.InferenceEngine +import android.llama.cpp.LLamaTier +import android.util.Log +import androidx.datastore.core.DataStore +import androidx.datastore.preferences.core.Preferences +import androidx.datastore.preferences.core.edit +import androidx.datastore.preferences.core.intPreferencesKey +import androidx.datastore.preferences.preferencesDataStore +import kotlinx.coroutines.flow.first +import 
kotlinx.coroutines.runBlocking + +/** + * Internal [android.llama.cpp.InferenceEngine] loader implementation + */ +internal object InferenceEngineLoader { + private val TAG = InferenceEngineLoader::class.simpleName + + // CPU feature detection preferences + private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection" + private val Context.llamaTierDataStore: DataStore + by preferencesDataStore(name = DATASTORE_CPU_DETECTION) + + private val DETECTION_VERSION = intPreferencesKey("detection_version") + private val DETECTED_TIER = intPreferencesKey("detected_tier") + + // Constants + private const val DATASTORE_VERSION = 1 + + @JvmStatic + private external fun getOptimalTier(): Int + + @JvmStatic + private external fun getCpuFeaturesString(): String + + private var _cachedInstance: InferenceEngineImpl? = null + private var _detectedTier: LLamaTier? = null + val detectedTier: LLamaTier? get() = _detectedTier + + /** + * Factory method to get a configured [InferenceEngineImpl] instance. + * Handles tier detection, caching, and library loading automatically. + */ + @Synchronized + fun createInstance(context: Context): InferenceEngine? 
{ + // Return cached instance if available + _cachedInstance?.let { return it } + + return runBlocking { + try { + // Obtain the optimal tier from cache if available + val tier = getOrDetectOptimalTier(context) ?: run { + Log.e(TAG, "Failed to determine optimal tier") + return@runBlocking null + } + _detectedTier = tier + Log.i(TAG, "Using tier: ${tier.name} (${tier.description})") + + // Create and cache the inference engine instance + val instance = InferenceEngineImpl.createWithTier(tier) ?: run { + Log.e(TAG, "Failed to instantiate InferenceEngineImpl") + return@runBlocking null + } + _cachedInstance = instance + Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}") + + instance + + } catch (e: Exception) { + Log.e(TAG, "Error creating InferenceEngineImpl instance", e) + null + } + } + } + + /** + * Clear cached detection results (for testing/debugging) + */ + fun clearCache(context: Context) { + runBlocking { context.llamaTierDataStore.edit { it.clear() } } + _cachedInstance = null + _detectedTier = null + Log.i(TAG, "Cleared detection results and cached instance") + } + + /** + * Get optimal tier from cache or detect it fresh + */ + private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? 
{ + val preferences = context.llamaTierDataStore.data.first() + + // Check if we have a cached result with the current detection version + val cachedVersion = preferences[DETECTION_VERSION] ?: -1 + val cachedTierValue = preferences[DETECTED_TIER] ?: -1 + if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) { + val cachedTier = LLamaTier.Companion.fromRawValue(cachedTierValue) + if (cachedTier != null) { + Log.i(TAG, "Using cached tier detection: ${cachedTier.name}") + return cachedTier + } + } + + // No valid cache, detect fresh + Log.i(TAG, "Performing fresh tier detection") + return detectAndCacheOptimalTier(context) + } + + /** + * Detect optimal tier and save to cache + */ + private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? { + try { + // Load CPU detection library + System.loadLibrary("llama_cpu_detector") + Log.i(TAG, "CPU feature detector loaded successfully") + + // Detect optimal tier + val tierValue = getOptimalTier() + val features = getCpuFeaturesString() + Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features") + + // Convert to enum and validate + val tier = LLamaTier.Companion.fromRawValue(tierValue) ?: run { + Log.w(TAG, "Invalid tier value $tierValue") + return null + } + + // Ensure we don't exceed maximum supported tier + val finalTier = if (tier.rawValue > LLamaTier.Companion.maxSupportedTier.rawValue) { + Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.Companion.maxSupportedTier.name}") + LLamaTier.Companion.maxSupportedTier + } else { + tier + } + + // Cache the result + context.llamaTierDataStore.edit { + it[DETECTED_TIER] = finalTier.rawValue + it[DETECTION_VERSION] = DATASTORE_VERSION + } + + Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}") + return finalTier + + } catch (e: UnsatisfiedLinkError) { + Log.e(TAG, "Failed to load CPU detection library", e) + + // Fallback to T0 and cache it + val fallbackTier = LLamaTier.T0 + 
context.llamaTierDataStore.edit { + it[DETECTED_TIER] = fallbackTier.rawValue + it[DETECTION_VERSION] = DATASTORE_VERSION + } + + Log.i(TAG, "Using fallback tier: ${fallbackTier.name}") + return fallbackTier + + } catch (e: Exception) { + Log.e(TAG, "Unexpected error during tier detection", e) + return null + } + } +} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt new file mode 100644 index 0000000000..1d7b731961 --- /dev/null +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt @@ -0,0 +1,15 @@ +package android.llama.cpp.internal + +import android.content.Context +import android.llama.cpp.LLamaTier +import android.llama.cpp.TierDetection + +/** + * Internal tier detection implementation + */ +internal class TierDetectionImpl(private val context: Context) : TierDetection { + override val detectedTier: LLamaTier? + get() = InferenceEngineLoader.detectedTier + + override fun clearCache() = InferenceEngineLoader.clearCache(context) +} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/gguf/GgufMetadataReaderImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/gguf/GgufMetadataReaderImpl.kt new file mode 100644 index 0000000000..44944d3223 --- /dev/null +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/gguf/GgufMetadataReaderImpl.kt @@ -0,0 +1,568 @@ +package android.llama.cpp.internal.gguf + +import android.llama.cpp.gguf.GgufMetadata +import android.llama.cpp.gguf.GgufMetadataReader +import java.io.File +import java.io.IOException +import java.io.InputStream + + +/** + * Utility class to read GGUF model files and extract metadata key-value pairs. + * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data. 
+ */ +internal class GgufMetadataReaderImpl( + private val skipKeys: Set, + private val arraySummariseThreshold: Int, +) : GgufMetadataReader { + companion object { + private const val ARCH_LLAMA = "llama" + } + + /** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */ + enum class MetadataType(val code: Int) { + UINT8(0), INT8(1), UINT16(2), INT16(3), + UINT32(4), INT32(5), FLOAT32(6), BOOL(7), + STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12); + companion object { + private val codeMap = values().associateBy(MetadataType::code) + fun fromCode(code: Int): MetadataType = codeMap[code] + ?: throw IOException("Unknown metadata value type code: $code") + } + } + + /** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */ + sealed class MetadataValue { + data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int + data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int + data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian) + data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian) + data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian) + data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian) + data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float + data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true) + data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed) + data class ArrayVal(val elementType: MetadataType, val elements: List) : MetadataValue() + data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian) + data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian) + data class Float64(val 
value: Double) : MetadataValue() // 12: 64-bit IEEE754 double + } + + /* Convert MetadataValue to plain Kotlin primitives for allMetadata map */ + private fun MetadataValue.toPrimitive(): Any = when (this) { + is MetadataValue.UInt8 -> value + is MetadataValue.Int8 -> value + is MetadataValue.UInt16 -> value + is MetadataValue.Int16 -> value + is MetadataValue.UInt32 -> value + is MetadataValue.Int32 -> value + is MetadataValue.Float32 -> value + is MetadataValue.Bool -> value + is MetadataValue.StringVal -> value + is MetadataValue.UInt64 -> value + is MetadataValue.Int64 -> value + is MetadataValue.Float64 -> value + is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() } + } + + /** + * High‑level entry point: parses a `.gguf` file on disk and returns the fully + * populated [GgufMetadata] tree. + * + * Steps performed internally: + * 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version). + * 2. Streams through the key‑value section, skipping large blobs if the key + * appears in [skipKeys] or if an array exceeds [arraySummariseThreshold]. + * 3. Converts the resulting raw map into strongly‑typed sub‑structures + * (basic info, tokenizer, rope, etc.). + * + * The method is STREAMING‑ONLY: tensors are never mapped or loaded into + * memory, so even multi‑GB model files can be processed in < 50 ms. + * + * @param path Absolute or relative filesystem path to a `.gguf` file. + * @return A [GgufMetadata] instance containing all recognised metadata plus + * an `allMetadata` map with any keys that were not given a dedicated + * field. + * @throws IOException if the file is not GGUF, the version is unsupported, + * or the metadata block is truncated / corrupt. + */ + override suspend fun readStructuredMetadata(path: String): GgufMetadata { + File(path).inputStream().buffered().use { input -> + // ── 1. 
header ──────────────────────────────────────────────────────
            // throws on mismatch
            val version = ensureMagicAndVersion(input)
            val tensorCount = readLittleLong(input)
            val kvCount = readLittleLong(input)

            // ── 2. metadata map (reuse our raw parser on the open stream) ──
            val meta = readMetaMap(input, kvCount)

            // ── 3. build structured object ────────────────────────────────
            return buildStructured(meta, version, tensorCount, kvCount)
        }
    }

    /** Reads the 4-byte magic + 4-byte version; throws if magic ≠ "GGUF". */
    private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
        // FIX: a single InputStream.read(buf) may legally return fewer than 4 bytes
        // (short read) without hitting EOF; readNBytesExact loops until filled.
        val magic = input.readNBytesExact(4)
        if (!magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46))) // "GGUF"
            throw IOException("Not a GGUF file (bad magic)")
        return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
    }

    /**
     * Read an unsigned 32-bit little-endian integer (as an Int bit-pattern).
     *
     * @throws IOException if fewer than four bytes are available.
     */
    private fun readLEUInt32(input: InputStream): Int =
        littleEndianBytesToInt(input.readNBytesExact(4))

    /**
     * Low-level helper that reads the entire key-value section from the current
     * stream position.
     *
     * @param input Open stream positioned JUST AFTER the header.
     * @param kvCnt Number of key-value pairs (taken from the header).
     * @return Map with one [MetadataValue] for every key that is NOT skipped.
     *
     * The function honours [skipKeys] and [arraySummariseThreshold] by invoking
     * [skipValue] or [parseValue] accordingly.
     */
    private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> {
        val map = mutableMapOf<String, MetadataValue>()
        repeat(kvCnt.toInt()) {
            val key = readString(input)
            val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
            if (key in skipKeys) {
                skipValue(input, valueT) // discard, but keep the stream position in sync
            } else {
                map[key] = parseValue(input, valueT)
            }
        }
        return map
    }

    /**
     * Converts a flat [Map]<[String], [MetadataValue]> into the strongly-typed
     * [GgufMetadata] tree used by the rest of the app.
     *
     * Only the keys listed in the spec are copied into dedicated data classes.
     *
     * @param m Raw key/value map.
     * @param version GGUF file-format version (enum).
     * @param tensorCnt Number of tensors (from the header).
     * @param kvCnt Total metadata pair count (from the header).
     */
    private fun buildStructured(
        m: Map<String, MetadataValue>,
        version: GgufMetadata.GgufVersion,
        tensorCnt: Long,
        kvCnt: Long
    ): GgufMetadata {
        // ---------- typed accessors over the raw map ----------
        fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
        fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
        fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
        fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
        fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
        fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
        fun String.strList(): List<String>? =
            (m[this] as? MetadataValue.ArrayVal)
                ?.elements
                ?.mapNotNull { (it as? MetadataValue.StringVal)?.value }

        val arch = "general.architecture".str() ?: ARCH_LLAMA

        // -------------- populate sections ----------------
        val basic = GgufMetadata.BasicInfo(
            uuid = "general.uuid".str(),
            name = "general.basename".str(),
            nameLabel = "general.name".str(),
            sizeLabel = "general.size_label".str()
        )

        val author = GgufMetadata.AuthorInfo(
            organization = "general.organization".str(),
            author = "general.author".str(),
            doi = "general.doi".str(),
            url = "general.url".str(),
            repoUrl = "general.repo_url".str(),
            license = "general.license".str(),
            licenseLink = "general.license.link".str()
        ).takeUnless {
            organization == null && author == null && doi == null &&
                url == null && repoUrl == null && license == null && licenseLink == null
        }

        val additional = GgufMetadata.AdditionalInfo(
            type = "general.type".str(),
            description = "general.description".str(),
            tags = "general.tags".strList(),
            languages = "general.languages".strList()
        ).takeUnless {
            type == null && description == null && tags == null && languages == null
        }

        val architectureInfo = GgufMetadata.ArchitectureInfo(
            architecture = arch,
            fileType = "general.file_type".u32(),
            vocabSize = "$arch.vocab_size".u32(),
            finetune = "general.finetune".str(),
            quantizationVersion = "general.quantization_version".u32()
        ).takeUnless {
            fileType == null && vocabSize == null && finetune == null && quantizationVersion == null
        }

        val baseModels = buildList {
            val n = "general.base_model.count".u32() ?: 0
            for (i in 0 until n) {
                // Keys for base model #i are namespaced: general.base_model.<i>.<field>
                fun k(s: String) = "general.base_model.$i.$s"
                add(
                    GgufMetadata.BaseModelInfo(
                        name = k("name").str(),
                        author = k("author").str(),
                        version = k("version").str(),
                        organization = k("organization").str(),
                        url = k("url").str(),
                        doi = k("doi").str(),
                        uuid = k("uuid").str(),
                        repoUrl = k("repo_url").str(),
                    )
                )
            }
        }.takeIf { it.isNotEmpty() }

        val tokenizer = GgufMetadata.TokenizerInfo(
            model = "tokenizer.ggml.model".str(),
            bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
            eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
            unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
            paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
            addBosToken = "tokenizer.ggml.add_bos_token".bool(),
            addEosToken = "tokenizer.ggml.add_eos_token".bool(),
            chatTemplate = "tokenizer.chat_template".str()
        ).takeUnless {
            model == null && bosTokenId == null && eosTokenId == null &&
                unknownTokenId == null && paddingTokenId == null &&
                addBosToken == null && addEosToken == null && chatTemplate == null
        }

        val dimensions = GgufMetadata.DimensionsInfo(
            contextLength = "$arch.context_length".u32(),
            embeddingSize = "$arch.embedding_length".u32(),
            blockCount = "$arch.block_count".u32(),
            feedForwardSize = "$arch.feed_forward_length".u32()
        ).takeUnless {
            contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null
        }

        val attention = GgufMetadata.AttentionInfo(
            headCount = "$arch.attention.head_count".u32(),
            headCountKv = "$arch.attention.head_count_kv".u32(),
            keyLength = "$arch.attention.key_length".u32(),
            valueLength = "$arch.attention.value_length".u32(),
            layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
            layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
        ).takeUnless {
            headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
                layerNormEpsilon == null && layerNormRmsEpsilon == null
        }

        val rope = GgufMetadata.RopeInfo(
            frequencyBase = "$arch.rope.freq_base".f32(),
            dimensionCount = "$arch.rope.dimension_count".u32(),
            scalingType = "$arch.rope.scaling.type".str(),
            scalingFactor = "$arch.rope.scaling.factor".f32(),
            attnFactor = "$arch.rope.scaling.attn_factor".f32(),
            originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
            finetuned = "$arch.rope.scaling.finetuned".bool()
        ).takeUnless {
            frequencyBase == null && dimensionCount == null &&
                scalingType == null && scalingFactor == null && attnFactor == null &&
                originalContextLength == null && finetuned == null
        }

        val experts = GgufMetadata.ExpertsInfo(
            count = "$arch.expert_count".u32(),
            usedCount = "$arch.expert_used_count".u32()
        ).takeUnless { count == null && usedCount == null }

        return GgufMetadata(
            version = version,
            tensorCount = tensorCnt,
            kvCount = kvCnt,
            basic = basic,
            author = author,
            additional = additional,
            architecture = architectureInfo,
            baseModels = baseModels,
            tokenizer = tokenizer,
            dimensions = dimensions,
            attention = attention,
            rope = rope,
            experts = experts
        )
    }

    /**
     * Recursively parses a metadata value of the given type from the input stream.
     *
     * FIX: all multi-byte reads now go through [readNBytesExact]/[readLittleLong],
     * which loop until the buffer is filled. The previous per-case
     * `input.read(bytes) != n` checks treated a legal short read as EOF.
     *
     * @param input The input stream positioned at the start of the value.
     * @param type The metadata value type to parse.
     */
    private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
        MetadataType.UINT8 -> {
            // 1-byte unsigned integer
            val byteVal = input.read()
            if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
            MetadataValue.UInt8(byteVal.toUByte())
        }
        MetadataType.INT8 -> {
            // 1-byte signed integer
            val byteVal = input.read()
            if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
            MetadataValue.Int8(byteVal.toByte())
        }
        MetadataType.UINT16 -> {
            // 2-byte unsigned integer (little-endian)
            val bytes = input.readNBytesExact(2)
            val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
            MetadataValue.UInt16(u16.toUShort())
        }
        MetadataType.INT16 -> {
            // 2-byte signed integer (little-endian), reinterpreted as signed via toShort()
            val bytes = input.readNBytesExact(2)
            val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
            MetadataValue.Int16(i16.toShort())
        }
        MetadataType.UINT32 ->
            // Same 32 bits as the original Long-based assembly, reinterpreted as UInt
            MetadataValue.UInt32(readLEUInt32(input).toUInt())
        MetadataType.INT32 ->
            MetadataValue.Int32(readLEUInt32(input))
        MetadataType.FLOAT32 ->
            // 4-byte IEEE 754 float: assemble the bit-pattern, then reinterpret
            MetadataValue.Float32(Float.fromBits(readLEUInt32(input)))
        MetadataType.BOOL -> {
            // 1-byte boolean; the spec allows only 0 or 1
            val byteVal = input.read()
            if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
            if (byteVal != 0 && byteVal != 1) {
                throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
            }
            MetadataValue.Bool(byteVal != 0)
        }
        MetadataType.STRING ->
            // UTF-8 string (length-prefixed with 8-byte length)
            MetadataValue.StringVal(readString(input))
        MetadataType.ARRAY -> {
            val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
            val count = readLittleLong(input).toInt()
            if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
                // fast-forward without allocation
                repeat(count) { skipValue(input, elemType) }
                MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
            } else {
                val list = ArrayList<MetadataValue>(count)
                repeat(count) { list += parseValue(input, elemType) }
                MetadataValue.ArrayVal(elemType, list)
            }
        }
        MetadataType.UINT64 ->
            // 8-byte unsigned integer: Long -> ULong keeps the same 64-bit pattern
            MetadataValue.UInt64(readLittleLong(input).toULong())
        MetadataType.INT64 ->
            MetadataValue.Int64(readLittleLong(input))
        MetadataType.FLOAT64 ->
            // 8-byte IEEE 754 double: assemble the bit-pattern, then reinterpret
            MetadataValue.Float64(Double.fromBits(readLittleLong(input)))
    }

    /** Returns null when [check] holds for the receiver (used to drop all-null sections). */
    private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
        this?.takeIf { !it.check() }

    /** Helper: Skip a value in the stream without storing it (still maintains pointer). */
    private fun skipValue(input: InputStream, type: MetadataType) {
        when (type) {
            MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
            MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
            MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
            MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
            MetadataType.STRING -> {
                val len = readLittleLong(input)
                input.skipFully(len)
            }
            MetadataType.ARRAY -> {
                val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
                val len = readLittleLong(input)
                repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
            }
        }
    }

    /**
     * Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long.
     *
     * Note: if the value exceeds Long.MAX_VALUE (bit 63 set) this produces a negative
     * Long (two's complement). In our context (lengths/counts) such extremely large
     * values are not expected.
     */
    private fun readLittleLong(input: InputStream): Long {
        val bytes = ByteArray(8)
        input.readFully(bytes)
        return (bytes[7].toLong() and 0xFFL shl 56) or
            (bytes[6].toLong() and 0xFFL shl 48) or
            (bytes[5].toLong() and 0xFFL shl 40) or
            (bytes[4].toLong() and 0xFFL shl 32) or
            (bytes[3].toLong() and 0xFFL shl 24) or
            (bytes[2].toLong() and 0xFFL shl 16) or
            (bytes[1].toLong() and 0xFFL shl 8) or
            (bytes[0].toLong() and 0xFFL)
    }

    /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
    private fun readString(input: InputStream): String {
        // 8-byte little-endian length (number of bytes in the string).
        val len = readLittleLong(input)
        if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")

        // Read the UTF-8 bytes of the given length.
        val buf = ByteArray(len.toInt())
        if (buf.isNotEmpty()) input.readFully(buf)
        return String(buf, Charsets.UTF_8)
    }

    /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
    private fun littleEndianBytesToInt(bytes: ByteArray): Int {
        // Note: assumes bytes length is 4.
        return (bytes[3].toInt() and 0xFF shl 24) or
            (bytes[2].toInt() and 0xFF shl 16) or
            (bytes[1].toInt() and 0xFF shl 8) or
            (bytes[0].toInt() and 0xFF)
    }

    /**
     * Robust skip that works the same on JDK 11 and Android's desugared runtime.
     *
     * @param n Number of bytes to advance in the stream.
     * @throws IOException on premature EOF.
     */
    private fun InputStream.skipFully(n: Long) {
        var remaining = n
        val scratch = ByteArray(8192) // read-and-toss buffer
        while (remaining > 0) {
            val skipped = skip(remaining)
            when {
                skipped > 0 -> remaining -= skipped // normal fast path
                skipped == 0L -> {
                    // fallback: skip() may legally return 0 — read and discard instead
                    val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
                    if (read == -1) throw IOException("EOF while skipping $n bytes")
                    remaining -= read
                }
                else -> throw IOException("Skip returned negative value")
            }
        }
    }

    /**
     * Extension that keeps reading until the requested number of bytes are filled.
     * A single read() may return fewer bytes than asked on some Android streams.
     *
     * @param buf Destination buffer.
     * @param len Number of bytes to fill (defaults to `buf.size`).
     * @throws IOException on premature EOF.
     */
    private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
        var off = 0
        while (off < len) {
            val n = read(buf, off, len - off)
            if (n == -1) throw IOException("EOF after $off of $len bytes")
            off += n
        }
    }

    /**
     * Read EXACTLY `n` bytes or throw — never returns a partially-filled array.
     * This is used for small fixed-length reads (e.g. 4-byte type codes).
     *
     * FIX: the previous implementation used a single read() call and treated a
     * short read as EOF; delegate to [readFully], which loops until filled.
     *
     * @throws IOException on premature EOF.
     */
    private fun InputStream.readNBytesExact(n: Int): ByteArray {
        val buf = ByteArray(n)
        readFully(buf)
        return buf
    }
}