lib: hide the internal implementations, only expose a facade and interfaces
This commit is contained in:
parent
57c3a9dda7
commit
c5058366dc
|
|
@ -2,7 +2,7 @@ package com.example.llama.di
|
|||
|
||||
import android.content.Context
|
||||
import android.llama.cpp.InferenceEngine
|
||||
import android.llama.cpp.InferenceEngineLoader
|
||||
import android.llama.cpp.KleidiLlama
|
||||
import android.llama.cpp.gguf.GgufMetadataReader
|
||||
import com.example.llama.data.local.AppDatabase
|
||||
import com.example.llama.data.remote.HuggingFaceApiService
|
||||
|
|
@ -66,7 +66,7 @@ internal abstract class AppModule {
|
|||
return if (USE_STUB_ENGINE) {
|
||||
StubInferenceEngine()
|
||||
} else {
|
||||
InferenceEngineLoader.createInstance(context)
|
||||
KleidiLlama.createInferenceEngine(context)
|
||||
?: throw InstantiationException("Cannot instantiate InferenceEngine!")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ static const Aarch64Info info = GetAarch64Info();
|
|||
static const Aarch64Features features = info.features;
|
||||
|
||||
extern "C" JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineLoader_getOptimalTier(
|
||||
Java_android_llama_cpp_internal_InferenceEngineLoader_getOptimalTier(
|
||||
JNIEnv* env,
|
||||
jclass clazz) {
|
||||
int tier = 0; // Default to T0 (baseline)
|
||||
|
|
@ -46,7 +46,7 @@ Java_android_llama_cpp_InferenceEngineLoader_getOptimalTier(
|
|||
|
||||
// Optional: Keep a feature string function for debugging
|
||||
extern "C" JNIEXPORT jstring JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineLoader_getCpuFeaturesString(
|
||||
Java_android_llama_cpp_internal_InferenceEngineLoader_getCpuFeaturesString(
|
||||
JNIEnv* env,
|
||||
jclass clazz) {
|
||||
std::string text;
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/) {
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/) {
|
||||
// Set llama log handler to Android
|
||||
llama_log_set(log_callback, nullptr);
|
||||
|
||||
|
|
@ -83,7 +83,7 @@ Java_android_llama_cpp_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/)
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
||||
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
|
||||
|
|
@ -137,7 +137,7 @@ static common_sampler *new_sampler(float temp) {
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
|
||||
auto *context = init_context(g_model);
|
||||
if (!context) { return 1; }
|
||||
g_context = context;
|
||||
|
|
@ -161,13 +161,13 @@ static std::string get_backend() {
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
|
||||
return env->NewStringUTF(llama_print_system_info());
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
|
||||
jint pl, jint nr) {
|
||||
auto *context = init_context(g_model, pp);
|
||||
if (!context) {
|
||||
|
|
@ -377,7 +377,7 @@ static int decode_tokens_in_batches(
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_processSystemPrompt(
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_processSystemPrompt(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/,
|
||||
jstring jsystem_prompt
|
||||
|
|
@ -426,7 +426,7 @@ Java_android_llama_cpp_InferenceEngineImpl_processSystemPrompt(
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_processUserPrompt(
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_processUserPrompt(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/,
|
||||
jstring juser_prompt,
|
||||
|
|
@ -510,7 +510,7 @@ static bool is_valid_utf8(const char *string) {
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_generateNextToken(
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_generateNextToken(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/
|
||||
) {
|
||||
|
|
@ -570,7 +570,7 @@ Java_android_llama_cpp_InferenceEngineImpl_generateNextToken(
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||
// Reset long-term & short-term states
|
||||
reset_long_term_states();
|
||||
reset_short_term_states();
|
||||
|
|
@ -585,6 +585,6 @@ Java_android_llama_cpp_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_android_llama_cpp_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) {
|
||||
Java_android_llama_cpp_internal_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) {
|
||||
llama_backend_free();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,179 +0,0 @@
|
|||
package android.llama.cpp
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import androidx.datastore.core.DataStore
|
||||
import androidx.datastore.preferences.core.Preferences
|
||||
import androidx.datastore.preferences.core.edit
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import androidx.datastore.preferences.preferencesDataStore
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.runBlocking
|
||||
|
||||
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
|
||||
T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
|
||||
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
|
||||
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
|
||||
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2");
|
||||
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
|
||||
|
||||
companion object {
|
||||
fun fromRawValue(value: Int): LLamaTier? {
|
||||
return entries.find { it.rawValue == value }
|
||||
}
|
||||
|
||||
fun getMaxSupportedTier(): LLamaTier = T3
|
||||
}
|
||||
}
|
||||
|
||||
class InferenceEngineLoader private constructor() {
|
||||
|
||||
companion object {
|
||||
private val TAG = InferenceEngineLoader::class.simpleName
|
||||
|
||||
// CPU feature detection preferences
|
||||
private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
|
||||
private val Context.llamaTierDataStore: DataStore<Preferences>
|
||||
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
|
||||
|
||||
private val DETECTION_VERSION = intPreferencesKey("detection_version")
|
||||
private val DETECTED_TIER = intPreferencesKey("detected_tier")
|
||||
|
||||
// Constants
|
||||
private const val DATASTORE_VERSION = 1
|
||||
|
||||
@JvmStatic
|
||||
private external fun getOptimalTier(): Int
|
||||
|
||||
@JvmStatic
|
||||
private external fun getCpuFeaturesString(): String
|
||||
|
||||
private var _cachedInstance: InferenceEngineImpl? = null
|
||||
private var _detectedTier: LLamaTier? = null
|
||||
val detectedTier: LLamaTier? get() = _detectedTier
|
||||
|
||||
/**
|
||||
* Factory method to get a configured [InferenceEngineImpl] instance.
|
||||
* Handles tier detection, caching, and library loading automatically.
|
||||
*/
|
||||
@Synchronized
|
||||
fun createInstance(context: Context): InferenceEngine? {
|
||||
// Return cached instance if available
|
||||
_cachedInstance?.let { return it }
|
||||
|
||||
return runBlocking {
|
||||
try {
|
||||
// Obtain the optimal tier from cache if available
|
||||
val tier = getOrDetectOptimalTier(context) ?: run {
|
||||
Log.e(TAG, "Failed to determine optimal tier")
|
||||
return@runBlocking null
|
||||
}
|
||||
_detectedTier = tier
|
||||
Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
|
||||
|
||||
// Create and cache the inference engine instance
|
||||
val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
|
||||
Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
|
||||
return@runBlocking null
|
||||
}
|
||||
_cachedInstance = instance
|
||||
Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
|
||||
|
||||
instance
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached detection results (for testing/debugging)
|
||||
*/
|
||||
fun clearCache(context: Context) {
|
||||
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
|
||||
_cachedInstance = null
|
||||
_detectedTier = null
|
||||
Log.i(TAG, "Cleared detection results and cached instance")
|
||||
}
|
||||
|
||||
/**
|
||||
* Get optimal tier from cache or detect it fresh
|
||||
*/
|
||||
private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
|
||||
val preferences = context.llamaTierDataStore.data.first()
|
||||
|
||||
// Check if we have a cached result with the current detection version
|
||||
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
|
||||
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
|
||||
if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
|
||||
val cachedTier = LLamaTier.fromRawValue(cachedTierValue)
|
||||
if (cachedTier != null) {
|
||||
Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
|
||||
return cachedTier
|
||||
}
|
||||
}
|
||||
|
||||
// No valid cache, detect fresh
|
||||
Log.i(TAG, "Performing fresh tier detection")
|
||||
return detectAndCacheOptimalTier(context)
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect optimal tier and save to cache
|
||||
*/
|
||||
private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
|
||||
try {
|
||||
// Load CPU detection library
|
||||
System.loadLibrary("llama_cpu_detector")
|
||||
Log.i(TAG, "CPU feature detector loaded successfully")
|
||||
|
||||
// Detect optimal tier
|
||||
val tierValue = getOptimalTier()
|
||||
val features = getCpuFeaturesString()
|
||||
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
|
||||
|
||||
// Convert to enum and validate
|
||||
val tier = LLamaTier.fromRawValue(tierValue) ?: run {
|
||||
Log.w(TAG, "Invalid tier value $tierValue")
|
||||
return null
|
||||
}
|
||||
|
||||
// Ensure we don't exceed maximum supported tier
|
||||
val finalTier = if (tier.rawValue > LLamaTier.getMaxSupportedTier().rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.getMaxSupportedTier().name}")
|
||||
LLamaTier.getMaxSupportedTier()
|
||||
} else {
|
||||
tier
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
context.llamaTierDataStore.edit {
|
||||
it[DETECTED_TIER] = finalTier.rawValue
|
||||
it[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
||||
Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
|
||||
return finalTier
|
||||
|
||||
} catch (e: UnsatisfiedLinkError) {
|
||||
Log.e(TAG, "Failed to load CPU detection library", e)
|
||||
|
||||
// Fallback to T0 and cache it
|
||||
val fallbackTier = LLamaTier.T0
|
||||
context.llamaTierDataStore.edit {
|
||||
it[DETECTED_TIER] = fallbackTier.rawValue
|
||||
it[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
||||
Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
|
||||
return fallbackTier
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Unexpected error during tier detection", e)
|
||||
return null
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
package android.llama.cpp
|
||||
|
||||
import android.content.Context
|
||||
import android.llama.cpp.internal.InferenceEngineFactory
|
||||
|
||||
/**
|
||||
* Main entry point for the Llama Android library.
|
||||
* This is the only class that should be used by library consumers.
|
||||
*/
|
||||
object KleidiLlama {
|
||||
/**
|
||||
* Create an inference engine instance with automatic tier detection.
|
||||
*/
|
||||
fun createInferenceEngine(context: Context) =
|
||||
InferenceEngineFactory.createInstance(context)
|
||||
|
||||
/**
|
||||
* Get tier detection information for debugging/settings.
|
||||
*/
|
||||
fun getTierDetection(context: Context) =
|
||||
InferenceEngineFactory.getTierDetection(context)
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
package android.llama.cpp
|
||||
|
||||
/**
|
||||
* Public interface for [LLamaTier] detection information.
|
||||
*/
|
||||
interface TierDetection {
|
||||
val detectedTier: LLamaTier?
|
||||
fun clearCache()
|
||||
}
|
||||
|
||||
/**
|
||||
* ARM optimization tiers supported by the Kleidi-Llama library.
|
||||
* Higher tiers provide better performance on supported hardware.
|
||||
*/
|
||||
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
|
||||
T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
|
||||
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
|
||||
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
|
||||
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2");
|
||||
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
|
||||
|
||||
companion object {
|
||||
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }
|
||||
|
||||
val maxSupportedTier = T3
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
package android.llama.cpp.gguf
|
||||
|
||||
|
||||
/**
|
||||
* Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
|
||||
* The `label` matches what llama‑cli prints.
|
||||
*/
|
||||
enum class FileType(val code: Int, val label: String) {
|
||||
ALL_F32(0, "all F32"),
|
||||
MOSTLY_F16(1, "F16"),
|
||||
MOSTLY_Q4_0(2, "Q4_0"),
|
||||
MOSTLY_Q4_1(3, "Q4_1"),
|
||||
// 4 removed
|
||||
MOSTLY_Q8_0(7, "Q8_0"),
|
||||
MOSTLY_Q5_0(8, "Q5_0"),
|
||||
MOSTLY_Q5_1(9, "Q5_1"),
|
||||
|
||||
/* K‑quants ------------------------------------------------------------ */
|
||||
MOSTLY_Q2_K (10, "Q2_K - Medium"),
|
||||
MOSTLY_Q3_K_S (11, "Q3_K - Small"),
|
||||
MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
|
||||
MOSTLY_Q3_K_L (13, "Q3_K - Large"),
|
||||
MOSTLY_Q4_K_S (14, "Q4_K - Small"),
|
||||
MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
|
||||
MOSTLY_Q5_K_S (16, "Q5_K - Small"),
|
||||
MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
|
||||
MOSTLY_Q6_K (18, "Q6_K"),
|
||||
|
||||
/* IQ quants ----------------------------------------------------------- */
|
||||
MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
|
||||
MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
|
||||
MOSTLY_Q2_K_S (21, "Q2_K - Small"),
|
||||
MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
|
||||
MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
|
||||
MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
|
||||
MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
|
||||
MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
|
||||
MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
|
||||
MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
|
||||
MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
|
||||
MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
|
||||
MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
|
||||
|
||||
/* BF16 & Ternary ------------------------------------------------------ */
|
||||
MOSTLY_BF16 (32, "BF16"),
|
||||
MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
|
||||
MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
|
||||
|
||||
/* Special flag -------------------------------------------------------- */
|
||||
GUESSED(1024, "(guessed)"),
|
||||
|
||||
UNKNOWN(-1, "unknown");
|
||||
|
||||
companion object {
|
||||
private val map = entries.associateBy(FileType::code)
|
||||
|
||||
fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
|
||||
}
|
||||
}
|
||||
|
|
@ -130,60 +130,3 @@ data class GgufMetadata(
|
|||
val usedCount: Int? = null,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
|
||||
* The `label` matches what llama‑cli prints.
|
||||
*/
|
||||
enum class FileType(val code: Int, val label: String) {
|
||||
ALL_F32(0, "all F32"),
|
||||
MOSTLY_F16(1, "F16"),
|
||||
MOSTLY_Q4_0(2, "Q4_0"),
|
||||
MOSTLY_Q4_1(3, "Q4_1"),
|
||||
// 4 removed
|
||||
MOSTLY_Q8_0(7, "Q8_0"),
|
||||
MOSTLY_Q5_0(8, "Q5_0"),
|
||||
MOSTLY_Q5_1(9, "Q5_1"),
|
||||
|
||||
/* K‑quants ------------------------------------------------------------ */
|
||||
MOSTLY_Q2_K (10, "Q2_K - Medium"),
|
||||
MOSTLY_Q3_K_S (11, "Q3_K - Small"),
|
||||
MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
|
||||
MOSTLY_Q3_K_L (13, "Q3_K - Large"),
|
||||
MOSTLY_Q4_K_S (14, "Q4_K - Small"),
|
||||
MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
|
||||
MOSTLY_Q5_K_S (16, "Q5_K - Small"),
|
||||
MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
|
||||
MOSTLY_Q6_K (18, "Q6_K"),
|
||||
|
||||
/* IQ quants ----------------------------------------------------------- */
|
||||
MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
|
||||
MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
|
||||
MOSTLY_Q2_K_S (21, "Q2_K - Small"),
|
||||
MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
|
||||
MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
|
||||
MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
|
||||
MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
|
||||
MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
|
||||
MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
|
||||
MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
|
||||
MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
|
||||
MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
|
||||
MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
|
||||
|
||||
/* BF16 & Ternary ------------------------------------------------------ */
|
||||
MOSTLY_BF16 (32, "BF16"),
|
||||
MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
|
||||
MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
|
||||
|
||||
/* Special flag -------------------------------------------------------- */
|
||||
GUESSED(1024, "(guessed)"),
|
||||
|
||||
UNKNOWN(-1, "unknown");
|
||||
|
||||
companion object {
|
||||
private val map = entries.associateBy(FileType::code)
|
||||
|
||||
fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
package android.llama.cpp.gguf
|
||||
|
||||
import java.io.File
|
||||
import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
|
||||
/**
|
||||
* Interface for reading GGUF metadata from model files.
|
||||
|
|
@ -51,563 +50,3 @@ interface GgufMetadataReader {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility class to read GGUF model files and extract metadata key-value pairs.
|
||||
* This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
|
||||
*/
|
||||
private class GgufMetadataReaderImpl(
|
||||
private val skipKeys: Set<String>,
|
||||
private val arraySummariseThreshold: Int,
|
||||
) : GgufMetadataReader {
|
||||
companion object {
|
||||
private const val ARCH_LLAMA = "llama"
|
||||
}
|
||||
|
||||
/** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
|
||||
enum class MetadataType(val code: Int) {
|
||||
UINT8(0), INT8(1), UINT16(2), INT16(3),
|
||||
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
|
||||
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
|
||||
companion object {
|
||||
private val codeMap = values().associateBy(MetadataType::code)
|
||||
fun fromCode(code: Int): MetadataType = codeMap[code]
|
||||
?: throw IOException("Unknown metadata value type code: $code")
|
||||
}
|
||||
}
|
||||
|
||||
/** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
|
||||
sealed class MetadataValue {
|
||||
data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
|
||||
data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
|
||||
data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
|
||||
data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
|
||||
data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
|
||||
data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
|
||||
data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
|
||||
data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
|
||||
data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
|
||||
data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
|
||||
data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
|
||||
data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
|
||||
data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
|
||||
}
|
||||
|
||||
/* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
|
||||
private fun MetadataValue.toPrimitive(): Any = when (this) {
|
||||
is MetadataValue.UInt8 -> value
|
||||
is MetadataValue.Int8 -> value
|
||||
is MetadataValue.UInt16 -> value
|
||||
is MetadataValue.Int16 -> value
|
||||
is MetadataValue.UInt32 -> value
|
||||
is MetadataValue.Int32 -> value
|
||||
is MetadataValue.Float32 -> value
|
||||
is MetadataValue.Bool -> value
|
||||
is MetadataValue.StringVal -> value
|
||||
is MetadataValue.UInt64 -> value
|
||||
is MetadataValue.Int64 -> value
|
||||
is MetadataValue.Float64 -> value
|
||||
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
|
||||
}
|
||||
|
||||
/**
|
||||
* High‑level entry point: parses a `.gguf` file on disk and returns the fully
|
||||
* populated [GgufMetadata] tree.
|
||||
*
|
||||
* Steps performed internally:
|
||||
* 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version).
|
||||
* 2. Streams through the key‑value section, skipping large blobs if the key
|
||||
* appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
|
||||
* 3. Converts the resulting raw map into strongly‑typed sub‑structures
|
||||
* (basic info, tokenizer, rope, etc.).
|
||||
*
|
||||
* The method is STREAMING‑ONLY: tensors are never mapped or loaded into
|
||||
* memory, so even multi‑GB model files can be processed in < 50 ms.
|
||||
*
|
||||
* @param path Absolute or relative filesystem path to a `.gguf` file.
|
||||
* @return A [GgufMetadata] instance containing all recognised metadata plus
|
||||
* an `allMetadata` map with any keys that were not given a dedicated
|
||||
* field.
|
||||
* @throws IOException if the file is not GGUF, the version is unsupported,
|
||||
* or the metadata block is truncated / corrupt.
|
||||
*/
|
||||
override suspend fun readStructuredMetadata(path: String): GgufMetadata {
|
||||
File(path).inputStream().buffered().use { input ->
|
||||
// ── 1. header ──────────────────────────────────────────────────────────
|
||||
// throws on mismatch
|
||||
val version = ensureMagicAndVersion(input)
|
||||
val tensorCount = readLittleLong(input)
|
||||
val kvCount = readLittleLong(input)
|
||||
|
||||
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
|
||||
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
|
||||
|
||||
// ── 3. build structured object ────────────────────────────────────────
|
||||
return buildStructured(meta, version, tensorCount, kvCount)
|
||||
}
|
||||
}
|
||||
|
||||
/** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
|
||||
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
|
||||
val magic = ByteArray(4)
|
||||
if (input.read(magic) != 4) throw IOException("File too short (no magic)")
|
||||
if (!magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46))) // "GGUF"
|
||||
throw IOException("Not a GGUF file (bad magic)")
|
||||
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
|
||||
}
|
||||
|
||||
/**
|
||||
* Read an unsigned 32‑bit little‑endian integer.
|
||||
*
|
||||
* @throws IOException if fewer than four bytes are available.
|
||||
*/
|
||||
private fun readLEUInt32(input: InputStream): Int {
|
||||
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
|
||||
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
|
||||
return (b3 and 0xFF shl 24) or
|
||||
(b2 and 0xFF shl 16) or
|
||||
(b1 and 0xFF shl 8) or
|
||||
(b0 and 0xFF)
|
||||
}
|
||||
|
||||
/**
|
||||
* Low‑level helper that reads the entire “key-value” section from the current
|
||||
* stream position.
|
||||
*
|
||||
* @param input Open stream positioned JUST AFTER the header.
|
||||
* @param kvCnt Number of key‑value pairs (taken from the header).
|
||||
* @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
|
||||
*
|
||||
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
|
||||
* [skipValue] or [parseValue] accordingly.
|
||||
*/
|
||||
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> {
|
||||
val map = mutableMapOf<String, MetadataValue>()
|
||||
repeat(kvCnt.toInt()) {
|
||||
val key = readString(input)
|
||||
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
if (key in skipKeys) {
|
||||
skipValue(input, valueT)
|
||||
} else {
|
||||
map[key] = parseValue(input, valueT)
|
||||
}
|
||||
}
|
||||
return map
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed
|
||||
* [GgufMetadata] tree used by the rest of the app.
|
||||
*
|
||||
* Only the keys listed in the spec are copied into dedicated data classes;
|
||||
* everything else is preserved in `GgufMetadata.allMetadata`.
|
||||
*
|
||||
* @param m Raw key/value map.
|
||||
* @param version GGUF file‑format version (enum).
|
||||
* @param tensorCnt Number of tensors (from the header).
|
||||
* @param kvCnt Total metadata pair count (from the header).
|
||||
*/
|
||||
private fun buildStructured(
|
||||
m: Map<String, MetadataValue>,
|
||||
version: GgufMetadata.GgufVersion,
|
||||
tensorCnt: Long,
|
||||
kvCnt: Long
|
||||
): GgufMetadata {
|
||||
// ---------- helpers ----------
|
||||
fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
|
||||
fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
|
||||
fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
|
||||
fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
|
||||
fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
|
||||
fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
|
||||
fun String.strList(): List<String>? =
|
||||
(m[this] as? MetadataValue.ArrayVal)
|
||||
?.elements
|
||||
?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
|
||||
|
||||
val arch = "general.architecture".str() ?: ARCH_LLAMA
|
||||
|
||||
// -------------- populate sections ----------------
|
||||
val basic = GgufMetadata.BasicInfo(
|
||||
uuid = "general.uuid".str(),
|
||||
name = "general.basename".str(),
|
||||
nameLabel = "general.name".str(),
|
||||
sizeLabel = "general.size_label".str()
|
||||
)
|
||||
|
||||
val author = GgufMetadata.AuthorInfo(
|
||||
organization = "general.organization".str(),
|
||||
author = "general.author".str(),
|
||||
doi = "general.doi".str(),
|
||||
url = "general.url".str(),
|
||||
repoUrl = "general.repo_url".str(),
|
||||
license = "general.license".str(),
|
||||
licenseLink = "general.license.link".str()
|
||||
).takeUnless {
|
||||
organization == null && author == null && doi == null &&
|
||||
url == null && repoUrl == null && license == null && licenseLink == null
|
||||
}
|
||||
|
||||
val additional = GgufMetadata.AdditionalInfo(
|
||||
type = "general.type".str(),
|
||||
description = "general.description".str(),
|
||||
tags = "general.tags".strList(),
|
||||
languages = "general.languages".strList()
|
||||
).takeUnless {
|
||||
type == null && description == null && tags == null && languages == null
|
||||
}
|
||||
|
||||
val architectureInfo = GgufMetadata.ArchitectureInfo(
|
||||
architecture = arch,
|
||||
fileType = "general.file_type".u32(),
|
||||
vocabSize = "$arch.vocab_size".u32(),
|
||||
finetune = "general.finetune".str(),
|
||||
quantizationVersion = "general.quantization_version".u32()
|
||||
).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
|
||||
|
||||
val baseModels = buildList {
|
||||
val n = "general.base_model.count".u32() ?: 0
|
||||
for (i in 0 until n) {
|
||||
fun k(s: String) = "general.base_model.$i.$s"
|
||||
add(
|
||||
GgufMetadata.BaseModelInfo(
|
||||
name = k("name").str(),
|
||||
author = k("author").str(),
|
||||
version = k("version").str(),
|
||||
organization = k("organization").str(),
|
||||
url = k("url").str(),
|
||||
doi = k("doi").str(),
|
||||
uuid = k("uuid").str(),
|
||||
repoUrl = k("repo_url").str(),
|
||||
)
|
||||
)
|
||||
}
|
||||
}.takeIf { it.isNotEmpty() }
|
||||
|
||||
val tokenizer = GgufMetadata.TokenizerInfo(
|
||||
model = "tokenizer.ggml.model".str(),
|
||||
bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
|
||||
eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
|
||||
unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
|
||||
paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
|
||||
addBosToken = "tokenizer.ggml.add_bos_token".bool(),
|
||||
addEosToken = "tokenizer.ggml.add_eos_token".bool(),
|
||||
chatTemplate = "tokenizer.chat_template".str()
|
||||
).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
|
||||
unknownTokenId == null && paddingTokenId == null &&
|
||||
addBosToken == null && addEosToken == null && chatTemplate == null
|
||||
}
|
||||
|
||||
val dimensions = GgufMetadata.DimensionsInfo(
|
||||
contextLength = "$arch.context_length".u32(),
|
||||
embeddingSize = "$arch.embedding_length".u32(),
|
||||
blockCount = "$arch.block_count".u32(),
|
||||
feedForwardSize = "$arch.feed_forward_length".u32()
|
||||
).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
|
||||
|
||||
val attention = GgufMetadata.AttentionInfo(
|
||||
headCount = "$arch.attention.head_count".u32(),
|
||||
headCountKv = "$arch.attention.head_count_kv".u32(),
|
||||
keyLength = "$arch.attention.key_length".u32(),
|
||||
valueLength = "$arch.attention.value_length".u32(),
|
||||
layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
|
||||
layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
|
||||
).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
|
||||
layerNormEpsilon == null && layerNormRmsEpsilon == null
|
||||
}
|
||||
|
||||
val rope = GgufMetadata.RopeInfo(
|
||||
frequencyBase = "$arch.rope.freq_base".f32(),
|
||||
dimensionCount = "$arch.rope.dimension_count".u32(),
|
||||
scalingType = "$arch.rope.scaling.type".str(),
|
||||
scalingFactor = "$arch.rope.scaling.factor".f32(),
|
||||
attnFactor = "$arch.rope.scaling.attn_factor".f32(),
|
||||
originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
|
||||
finetuned = "$arch.rope.scaling.finetuned".bool()
|
||||
).takeUnless { frequencyBase == null && dimensionCount == null &&
|
||||
scalingType == null && scalingFactor == null && attnFactor == null &&
|
||||
originalContextLength == null && finetuned == null
|
||||
}
|
||||
|
||||
val experts = GgufMetadata.ExpertsInfo(
|
||||
count = "$arch.expert_count".u32(),
|
||||
usedCount = "$arch.expert_used_count".u32()
|
||||
).takeUnless { count == null && usedCount == null }
|
||||
|
||||
return GgufMetadata(
|
||||
version = version,
|
||||
tensorCount = tensorCnt,
|
||||
kvCount = kvCnt,
|
||||
basic = basic,
|
||||
author = author,
|
||||
additional = additional,
|
||||
architecture = architectureInfo,
|
||||
baseModels = baseModels,
|
||||
tokenizer = tokenizer,
|
||||
dimensions = dimensions,
|
||||
attention = attention,
|
||||
rope = rope,
|
||||
experts = experts
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively parses a metadata value of the given type from the input stream.
|
||||
* @param input The input stream positioned at the start of the value.
|
||||
* @param type The metadata value type to parse.
|
||||
*/
|
||||
private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
|
||||
MetadataType.UINT8 -> {
|
||||
// 1-byte unsigned integer
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
|
||||
MetadataValue.UInt8(byteVal.toUByte())
|
||||
}
|
||||
MetadataType.INT8 -> {
|
||||
// 1-byte signed integer
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
|
||||
MetadataValue.Int8(byteVal.toByte())
|
||||
}
|
||||
MetadataType.UINT16 -> {
|
||||
// 2-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(2)
|
||||
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
|
||||
// Combine two bytes (little-endian) into an unsigned 16-bit value
|
||||
val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.UInt16(u16.toUShort())
|
||||
}
|
||||
MetadataType.INT16 -> {
|
||||
// 2-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(2)
|
||||
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
|
||||
// Combine to 16-bit and interpret as signed
|
||||
val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.Int16(i16.toShort())
|
||||
}
|
||||
MetadataType.UINT32 -> {
|
||||
// 4-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
|
||||
// Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
|
||||
val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
MetadataValue.UInt32(u32.toUInt())
|
||||
}
|
||||
MetadataType.INT32 -> {
|
||||
// 4-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
|
||||
// Combine four bytes into a 32-bit signed int
|
||||
val i32 = (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.Int32(i32)
|
||||
}
|
||||
MetadataType.FLOAT32 -> {
|
||||
// 4-byte IEEE 754 float (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
|
||||
// Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
|
||||
val bits = (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
val floatVal = Float.fromBits(bits)
|
||||
MetadataValue.Float32(floatVal)
|
||||
}
|
||||
MetadataType.BOOL -> {
|
||||
// 1-byte boolean (0 = false, 1 = true)
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
|
||||
if (byteVal != 0 && byteVal != 1) {
|
||||
throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
|
||||
}
|
||||
MetadataValue.Bool(byteVal != 0)
|
||||
}
|
||||
MetadataType.STRING -> {
|
||||
// UTF-8 string (length-prefixed with 8-byte length)
|
||||
val str = readString(input)
|
||||
MetadataValue.StringVal(str)
|
||||
}
|
||||
MetadataType.ARRAY -> {
|
||||
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
val len = readLittleLong(input)
|
||||
val count = len.toInt()
|
||||
|
||||
if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
|
||||
// fast‑forward without allocation
|
||||
repeat(count) { skipValue(input, elemType) }
|
||||
MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
|
||||
} else {
|
||||
val list = ArrayList<MetadataValue>(count)
|
||||
repeat(count) { list += parseValue(input, elemType) }
|
||||
MetadataValue.ArrayVal(elemType, list)
|
||||
}
|
||||
}
|
||||
MetadataType.UINT64 -> {
|
||||
// 8-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
|
||||
// Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
|
||||
val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
|
||||
(bytes[6].toULong() and 0xFFuL shl 48) or
|
||||
(bytes[5].toULong() and 0xFFuL shl 40) or
|
||||
(bytes[4].toULong() and 0xFFuL shl 32) or
|
||||
(bytes[3].toULong() and 0xFFuL shl 24) or
|
||||
(bytes[2].toULong() and 0xFFuL shl 16) or
|
||||
(bytes[1].toULong() and 0xFFuL shl 8) or
|
||||
(bytes[0].toULong() and 0xFFuL)
|
||||
MetadataValue.UInt64(u64)
|
||||
}
|
||||
MetadataType.INT64 -> {
|
||||
// 8-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
|
||||
// Combine 8 bytes into a signed 64-bit value (Long)
|
||||
val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
MetadataValue.Int64(i64)
|
||||
}
|
||||
MetadataType.FLOAT64 -> {
|
||||
// 8-byte IEEE 754 double (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
|
||||
// Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
|
||||
val bits = (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
val doubleVal = Double.fromBits(bits)
|
||||
MetadataValue.Float64(doubleVal)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
|
||||
this?.takeIf { !it.check() }
|
||||
|
||||
/** Helper: Skip a value in the stream without storing it (still maintains pointer). */
|
||||
private fun skipValue(input: InputStream, type: MetadataType) {
|
||||
when (type) {
|
||||
MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
|
||||
MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
|
||||
MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
|
||||
MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
|
||||
MetadataType.STRING -> {
|
||||
val len = readLittleLong(input); input.skipFully(len)
|
||||
}
|
||||
MetadataType.ARRAY -> {
|
||||
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
val len = readLittleLong(input)
|
||||
repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
|
||||
private fun readLittleLong(input: InputStream): Long {
|
||||
val bytes = ByteArray(8)
|
||||
input.readFully(bytes)
|
||||
|
||||
// Combine 8 bytes into a 64-bit value (Little Endian).
|
||||
// Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
|
||||
// In our context (lengths/counts), such extremely large values are not expected.
|
||||
return (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
}
|
||||
|
||||
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
|
||||
private fun readString(input: InputStream): String {
|
||||
// Read 8-byte little-endian length (number of bytes in the string).
|
||||
val len = readLittleLong(input)
|
||||
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
|
||||
|
||||
// Read the UTF-8 bytes of the given length.
|
||||
val buf = ByteArray(len.toInt())
|
||||
if (buf.isNotEmpty()) input.readFully(buf)
|
||||
return String(buf, Charsets.UTF_8)
|
||||
}
|
||||
|
||||
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
|
||||
private fun littleEndianBytesToInt(bytes: ByteArray): Int {
|
||||
// Note: assumes bytes length is 4.
|
||||
return (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
}
|
||||
|
||||
/**
|
||||
* Robust skip that works the same on JDK 11 and Android’s desugared runtime.
|
||||
*
|
||||
* @param n Number of bytes to advance in the stream.
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.skipFully(n: Long) {
|
||||
var remaining = n
|
||||
val scratch = ByteArray(8192) // read‑and‑toss buffer
|
||||
while (remaining > 0) {
|
||||
val skipped = skip(remaining)
|
||||
when {
|
||||
skipped > 0 -> remaining -= skipped // normal fast path
|
||||
skipped == 0L -> {
|
||||
// fallback: read and discard
|
||||
val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
|
||||
if (read == -1) throw IOException("EOF while skipping $n bytes")
|
||||
remaining -= read
|
||||
}
|
||||
else -> throw IOException("Skip returned negative value")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extension that keeps reading until the requested number of bytes are filled.
|
||||
* Falls back to `read()` when `skip()` returns 0, which happens on some Android
|
||||
* streams.
|
||||
*
|
||||
* @param buf Destination buffer.
|
||||
* @param len Number of bytes to fill (defaults to `buf.size`).
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
|
||||
var off = 0
|
||||
while (off < len) {
|
||||
val n = read(buf, off, len - off)
|
||||
if (n == -1) throw IOException("EOF after $off of $len bytes")
|
||||
off += n
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read EXACTLY `n` bytes or throw – never returns a partially‑filled array.
|
||||
* This is used for small fixed‑length reads (e.g. 4‑byte type codes).
|
||||
*
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.readNBytesExact(n: Int): ByteArray {
|
||||
val buf = ByteArray(n)
|
||||
if (read(buf) != n) throw IOException("Unexpected EOF")
|
||||
return buf
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
package android.llama.cpp.internal
|
||||
|
||||
import android.content.Context
|
||||
import android.llama.cpp.TierDetection
|
||||
|
||||
/**
|
||||
* Internal factory to create [InferenceEngine] and [TierDetection]
|
||||
*/
|
||||
internal object InferenceEngineFactory {
|
||||
fun createInstance(context: Context) = InferenceEngineLoader.createInstance(context)
|
||||
|
||||
fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
package android.llama.cpp
|
||||
package android.llama.cpp.internal
|
||||
|
||||
import android.llama.cpp.InferenceEngine.State
|
||||
import android.llama.cpp.InferenceEngine
|
||||
import android.llama.cpp.LLamaTier
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
|
|
@ -85,8 +86,9 @@ internal class InferenceEngineImpl private constructor(
|
|||
private external fun unload()
|
||||
private external fun shutdown()
|
||||
|
||||
private val _state = MutableStateFlow<State>(State.Uninitialized)
|
||||
override val state: StateFlow<State> = _state
|
||||
private val _state =
|
||||
MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
|
||||
override val state: StateFlow<InferenceEngine.State> = _state
|
||||
|
||||
private var _readyForSystemPrompt = false
|
||||
|
||||
|
|
@ -100,15 +102,15 @@ internal class InferenceEngineImpl private constructor(
|
|||
init {
|
||||
llamaScope.launch {
|
||||
try {
|
||||
check(_state.value is State.Uninitialized) {
|
||||
check(_state.value is InferenceEngine.State.Uninitialized) {
|
||||
"Cannot load native library in ${_state.value.javaClass.simpleName}!"
|
||||
}
|
||||
_state.value = State.Initializing
|
||||
_state.value = InferenceEngine.State.Initializing
|
||||
Log.i(TAG, "Loading native library for $tier")
|
||||
|
||||
System.loadLibrary(tier.libraryName)
|
||||
init()
|
||||
_state.value = State.Initialized
|
||||
_state.value = InferenceEngine.State.Initialized
|
||||
Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
|
||||
|
||||
} catch (e: Exception) {
|
||||
|
|
@ -123,7 +125,7 @@ internal class InferenceEngineImpl private constructor(
|
|||
*/
|
||||
override suspend fun loadModel(pathToModel: String) =
|
||||
withContext(llamaDispatcher) {
|
||||
check(_state.value is State.Initialized) {
|
||||
check(_state.value is InferenceEngine.State.Initialized) {
|
||||
"Cannot load model in ${_state.value.javaClass.simpleName}!"
|
||||
}
|
||||
File(pathToModel).let {
|
||||
|
|
@ -133,7 +135,7 @@ internal class InferenceEngineImpl private constructor(
|
|||
|
||||
Log.i(TAG, "Loading model... \n$pathToModel")
|
||||
_readyForSystemPrompt = false
|
||||
_state.value = State.LoadingModel
|
||||
_state.value = InferenceEngine.State.LoadingModel
|
||||
load(pathToModel).let { result ->
|
||||
if (result != 0) throw IllegalStateException("Failed to Load model: $result")
|
||||
}
|
||||
|
|
@ -142,7 +144,7 @@ internal class InferenceEngineImpl private constructor(
|
|||
}
|
||||
Log.i(TAG, "Model loaded!")
|
||||
_readyForSystemPrompt = true
|
||||
_state.value = State.ModelReady
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -154,40 +156,40 @@ internal class InferenceEngineImpl private constructor(
|
|||
withContext(llamaDispatcher) {
|
||||
require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
|
||||
check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
|
||||
check(_state.value is State.ModelReady) {
|
||||
check(_state.value is InferenceEngine.State.ModelReady) {
|
||||
"Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
|
||||
}
|
||||
|
||||
Log.i(TAG, "Sending system prompt...")
|
||||
_readyForSystemPrompt = false
|
||||
_state.value = State.ProcessingSystemPrompt
|
||||
_state.value = InferenceEngine.State.ProcessingSystemPrompt
|
||||
processSystemPrompt(prompt).let { result ->
|
||||
if (result != 0) {
|
||||
val errorMessage = "Failed to process system prompt: $result"
|
||||
_state.value = State.Error(errorMessage)
|
||||
_state.value = InferenceEngine.State.Error(errorMessage)
|
||||
throw IllegalStateException(errorMessage)
|
||||
}
|
||||
}
|
||||
Log.i(TAG, "System prompt processed! Awaiting user prompt...")
|
||||
_state.value = State.ModelReady
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
}
|
||||
|
||||
/**
|
||||
* Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
|
||||
* Send plain text user prompt to LLM, which starts generating tokens in a [kotlinx.coroutines.flow.Flow]
|
||||
*/
|
||||
override fun sendUserPrompt(
|
||||
message: String,
|
||||
predictLength: Int,
|
||||
): Flow<String> = flow {
|
||||
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
|
||||
check(_state.value is State.ModelReady) {
|
||||
check(_state.value is InferenceEngine.State.ModelReady) {
|
||||
"User prompt discarded due to: ${_state.value.javaClass.simpleName}"
|
||||
}
|
||||
|
||||
try {
|
||||
Log.i(TAG, "Sending user prompt...")
|
||||
_readyForSystemPrompt = false
|
||||
_state.value = State.ProcessingUserPrompt
|
||||
_state.value = InferenceEngine.State.ProcessingUserPrompt
|
||||
|
||||
processUserPrompt(message, predictLength).let { result ->
|
||||
if (result != 0) {
|
||||
|
|
@ -197,21 +199,21 @@ internal class InferenceEngineImpl private constructor(
|
|||
}
|
||||
|
||||
Log.i(TAG, "User prompt processed. Generating assistant prompt...")
|
||||
_state.value = State.Generating
|
||||
_state.value = InferenceEngine.State.Generating
|
||||
while (true) {
|
||||
generateNextToken()?.let { utf8token ->
|
||||
if (utf8token.isNotEmpty()) emit(utf8token)
|
||||
} ?: break
|
||||
}
|
||||
Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
|
||||
_state.value = State.ModelReady
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
} catch (e: CancellationException) {
|
||||
Log.i(TAG, "Generation cancelled by user.")
|
||||
_state.value = State.ModelReady
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
throw e
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error during generation!", e)
|
||||
_state.value = State.Error(e.message ?: "Unknown error")
|
||||
_state.value = InferenceEngine.State.Error(e.message ?: "Unknown error")
|
||||
throw e
|
||||
}
|
||||
}.flowOn(llamaDispatcher)
|
||||
|
|
@ -221,14 +223,14 @@ internal class InferenceEngineImpl private constructor(
|
|||
*/
|
||||
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
|
||||
withContext(llamaDispatcher) {
|
||||
check(_state.value is State.ModelReady) {
|
||||
check(_state.value is InferenceEngine.State.ModelReady) {
|
||||
"Benchmark request discarded due to: $state"
|
||||
}
|
||||
Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
|
||||
_readyForSystemPrompt = false // Just to be safe
|
||||
_state.value = State.Benchmarking
|
||||
_state.value = InferenceEngine.State.Benchmarking
|
||||
benchModel(pp, tg, pl, nr).also {
|
||||
_state.value = State.ModelReady
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -237,18 +239,19 @@ internal class InferenceEngineImpl private constructor(
|
|||
*/
|
||||
override suspend fun unloadModel() =
|
||||
withContext(llamaDispatcher) {
|
||||
when(val state = _state.value) {
|
||||
is State.ModelReady, is State.Error -> {
|
||||
when (val state = _state.value) {
|
||||
is InferenceEngine.State.ModelReady, is InferenceEngine.State.Error -> {
|
||||
Log.i(TAG, "Unloading model and free resources...")
|
||||
_readyForSystemPrompt = false
|
||||
_state.value = State.UnloadingModel
|
||||
_state.value = InferenceEngine.State.UnloadingModel
|
||||
|
||||
unload()
|
||||
|
||||
_state.value = State.Initialized
|
||||
_state.value = InferenceEngine.State.Initialized
|
||||
Log.i(TAG, "Model unloaded!")
|
||||
Unit
|
||||
}
|
||||
|
||||
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
|
||||
}
|
||||
}
|
||||
|
|
@ -260,8 +263,8 @@ internal class InferenceEngineImpl private constructor(
|
|||
_readyForSystemPrompt = false
|
||||
llamaScope.cancel()
|
||||
when(_state.value) {
|
||||
is State.Uninitialized -> {}
|
||||
is State.Initialized -> shutdown()
|
||||
is InferenceEngine.State.Uninitialized -> {}
|
||||
is InferenceEngine.State.Initialized -> shutdown()
|
||||
else -> { unload(); shutdown() }
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
package android.llama.cpp.internal
|
||||
|
||||
import android.content.Context
|
||||
import android.llama.cpp.InferenceEngine
|
||||
import android.llama.cpp.LLamaTier
|
||||
import android.util.Log
|
||||
import androidx.datastore.core.DataStore
|
||||
import androidx.datastore.preferences.core.Preferences
|
||||
import androidx.datastore.preferences.core.edit
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import androidx.datastore.preferences.preferencesDataStore
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.runBlocking
|
||||
|
||||
/**
|
||||
* Internal [android.llama.cpp.InferenceEngine] loader implementation
|
||||
*/
|
||||
internal object InferenceEngineLoader {
|
||||
private val TAG = InferenceEngineLoader::class.simpleName
|
||||
|
||||
// CPU feature detection preferences
|
||||
private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
|
||||
private val Context.llamaTierDataStore: DataStore<Preferences>
|
||||
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
|
||||
|
||||
private val DETECTION_VERSION = intPreferencesKey("detection_version")
|
||||
private val DETECTED_TIER = intPreferencesKey("detected_tier")
|
||||
|
||||
// Constants
|
||||
private const val DATASTORE_VERSION = 1
|
||||
|
||||
@JvmStatic
|
||||
private external fun getOptimalTier(): Int
|
||||
|
||||
@JvmStatic
|
||||
private external fun getCpuFeaturesString(): String
|
||||
|
||||
private var _cachedInstance: InferenceEngineImpl? = null
|
||||
private var _detectedTier: LLamaTier? = null
|
||||
val detectedTier: LLamaTier? get() = _detectedTier
|
||||
|
||||
/**
|
||||
* Factory method to get a configured [InferenceEngineImpl] instance.
|
||||
* Handles tier detection, caching, and library loading automatically.
|
||||
*/
|
||||
@Synchronized
|
||||
fun createInstance(context: Context): InferenceEngine? {
|
||||
// Return cached instance if available
|
||||
_cachedInstance?.let { return it }
|
||||
|
||||
return runBlocking {
|
||||
try {
|
||||
// Obtain the optimal tier from cache if available
|
||||
val tier = getOrDetectOptimalTier(context) ?: run {
|
||||
Log.e(TAG, "Failed to determine optimal tier")
|
||||
return@runBlocking null
|
||||
}
|
||||
_detectedTier = tier
|
||||
Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
|
||||
|
||||
// Create and cache the inference engine instance
|
||||
val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
|
||||
Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
|
||||
return@runBlocking null
|
||||
}
|
||||
_cachedInstance = instance
|
||||
Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
|
||||
|
||||
instance
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached detection results (for testing/debugging)
|
||||
*/
|
||||
fun clearCache(context: Context) {
|
||||
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
|
||||
_cachedInstance = null
|
||||
_detectedTier = null
|
||||
Log.i(TAG, "Cleared detection results and cached instance")
|
||||
}
|
||||
|
||||
/**
|
||||
* Get optimal tier from cache or detect it fresh
|
||||
*/
|
||||
private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
|
||||
val preferences = context.llamaTierDataStore.data.first()
|
||||
|
||||
// Check if we have a cached result with the current detection version
|
||||
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
|
||||
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
|
||||
if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
|
||||
val cachedTier = LLamaTier.Companion.fromRawValue(cachedTierValue)
|
||||
if (cachedTier != null) {
|
||||
Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
|
||||
return cachedTier
|
||||
}
|
||||
}
|
||||
|
||||
// No valid cache, detect fresh
|
||||
Log.i(TAG, "Performing fresh tier detection")
|
||||
return detectAndCacheOptimalTier(context)
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect optimal tier and save to cache
|
||||
*/
|
||||
private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
|
||||
try {
|
||||
// Load CPU detection library
|
||||
System.loadLibrary("llama_cpu_detector")
|
||||
Log.i(TAG, "CPU feature detector loaded successfully")
|
||||
|
||||
// Detect optimal tier
|
||||
val tierValue = getOptimalTier()
|
||||
val features = getCpuFeaturesString()
|
||||
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
|
||||
|
||||
// Convert to enum and validate
|
||||
val tier = LLamaTier.Companion.fromRawValue(tierValue) ?: run {
|
||||
Log.w(TAG, "Invalid tier value $tierValue")
|
||||
return null
|
||||
}
|
||||
|
||||
// Ensure we don't exceed maximum supported tier
|
||||
val finalTier = if (tier.rawValue > LLamaTier.Companion.maxSupportedTier.rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.Companion.maxSupportedTier.name}")
|
||||
LLamaTier.Companion.maxSupportedTier
|
||||
} else {
|
||||
tier
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
context.llamaTierDataStore.edit {
|
||||
it[DETECTED_TIER] = finalTier.rawValue
|
||||
it[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
||||
Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
|
||||
return finalTier
|
||||
|
||||
} catch (e: UnsatisfiedLinkError) {
|
||||
Log.e(TAG, "Failed to load CPU detection library", e)
|
||||
|
||||
// Fallback to T0 and cache it
|
||||
val fallbackTier = LLamaTier.T0
|
||||
context.llamaTierDataStore.edit {
|
||||
it[DETECTED_TIER] = fallbackTier.rawValue
|
||||
it[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
||||
Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
|
||||
return fallbackTier
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Unexpected error during tier detection", e)
|
||||
return null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
package android.llama.cpp.internal
|
||||
|
||||
import android.content.Context
|
||||
import android.llama.cpp.LLamaTier
|
||||
import android.llama.cpp.TierDetection
|
||||
|
||||
/**
|
||||
* Internal tier detection implementation
|
||||
*/
|
||||
internal class TierDetectionImpl(private val context: Context) : TierDetection {
|
||||
override val detectedTier: LLamaTier?
|
||||
get() = InferenceEngineLoader.detectedTier
|
||||
|
||||
override fun clearCache() = InferenceEngineLoader.clearCache(context)
|
||||
}
|
||||
|
|
@ -0,0 +1,568 @@
|
|||
package android.llama.cpp.internal.gguf
|
||||
|
||||
import android.llama.cpp.gguf.GgufMetadata
|
||||
import android.llama.cpp.gguf.GgufMetadataReader
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
|
||||
|
||||
/**
|
||||
* Utility class to read GGUF model files and extract metadata key-value pairs.
|
||||
* This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
|
||||
*/
|
||||
internal class GgufMetadataReaderImpl(
|
||||
private val skipKeys: Set<String>,
|
||||
private val arraySummariseThreshold: Int,
|
||||
) : GgufMetadataReader {
|
||||
companion object {
|
||||
private const val ARCH_LLAMA = "llama"
|
||||
}
|
||||
|
||||
/** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
|
||||
enum class MetadataType(val code: Int) {
|
||||
UINT8(0), INT8(1), UINT16(2), INT16(3),
|
||||
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
|
||||
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
|
||||
companion object {
|
||||
private val codeMap = values().associateBy(MetadataType::code)
|
||||
fun fromCode(code: Int): MetadataType = codeMap[code]
|
||||
?: throw IOException("Unknown metadata value type code: $code")
|
||||
}
|
||||
}
|
||||
|
||||
/** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
|
||||
sealed class MetadataValue {
|
||||
data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
|
||||
data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
|
||||
data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
|
||||
data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
|
||||
data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
|
||||
data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
|
||||
data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
|
||||
data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
|
||||
data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
|
||||
data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
|
||||
data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
|
||||
data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
|
||||
data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
|
||||
}
|
||||
|
||||
/* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
|
||||
private fun MetadataValue.toPrimitive(): Any = when (this) {
|
||||
is MetadataValue.UInt8 -> value
|
||||
is MetadataValue.Int8 -> value
|
||||
is MetadataValue.UInt16 -> value
|
||||
is MetadataValue.Int16 -> value
|
||||
is MetadataValue.UInt32 -> value
|
||||
is MetadataValue.Int32 -> value
|
||||
is MetadataValue.Float32 -> value
|
||||
is MetadataValue.Bool -> value
|
||||
is MetadataValue.StringVal -> value
|
||||
is MetadataValue.UInt64 -> value
|
||||
is MetadataValue.Int64 -> value
|
||||
is MetadataValue.Float64 -> value
|
||||
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
|
||||
}
|
||||
|
||||
/**
|
||||
* High‑level entry point: parses a `.gguf` file on disk and returns the fully
|
||||
* populated [GgufMetadata] tree.
|
||||
*
|
||||
* Steps performed internally:
|
||||
* 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version).
|
||||
* 2. Streams through the key‑value section, skipping large blobs if the key
|
||||
* appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
|
||||
* 3. Converts the resulting raw map into strongly‑typed sub‑structures
|
||||
* (basic info, tokenizer, rope, etc.).
|
||||
*
|
||||
* The method is STREAMING‑ONLY: tensors are never mapped or loaded into
|
||||
* memory, so even multi‑GB model files can be processed in < 50 ms.
|
||||
*
|
||||
* @param path Absolute or relative filesystem path to a `.gguf` file.
|
||||
* @return A [GgufMetadata] instance containing all recognised metadata plus
|
||||
* an `allMetadata` map with any keys that were not given a dedicated
|
||||
* field.
|
||||
* @throws IOException if the file is not GGUF, the version is unsupported,
|
||||
* or the metadata block is truncated / corrupt.
|
||||
*/
|
||||
override suspend fun readStructuredMetadata(path: String): GgufMetadata {
|
||||
File(path).inputStream().buffered().use { input ->
|
||||
// ── 1. header ──────────────────────────────────────────────────────────
|
||||
// throws on mismatch
|
||||
val version = ensureMagicAndVersion(input)
|
||||
val tensorCount = readLittleLong(input)
|
||||
val kvCount = readLittleLong(input)
|
||||
|
||||
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
|
||||
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
|
||||
|
||||
// ── 3. build structured object ────────────────────────────────────────
|
||||
return buildStructured(meta, version, tensorCount, kvCount)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Reads the 4-byte magic + 4-byte version from the current stream position.
 *
 * Fix: the previous implementation called `input.read(magic)` once and treated
 * any return other than 4 as EOF, but `InputStream.read(byte[])` may legally
 * return fewer bytes than requested without being at end-of-stream. We now
 * loop until the magic is fully read.
 *
 * @throws IOException if the stream ends early or the magic is not "GGUF".
 */
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
    val magic = ByteArray(4)
    var filled = 0
    while (filled < magic.size) {
        val n = input.read(magic, filled, magic.size - filled)
        if (n == -1) throw IOException("File too short (no magic)")
        filled += n
    }
    if (!magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46))) // "GGUF"
        throw IOException("Not a GGUF file (bad magic)")
    return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
}
|
||||
|
||||
/**
 * Reads an unsigned 32-bit little-endian integer (returned in an Int's bits).
 *
 * @throws IOException if fewer than four bytes are available.
 */
private fun readLEUInt32(input: InputStream): Int {
    // Read the four bytes in stream order; -1 marks EOF.
    val raw = IntArray(4) { input.read() }
    if (raw[3] == -1) throw IOException("Unexpected EOF while reading UInt32")
    // Fold most-significant byte first (byte 3 down to byte 0).
    var value = 0
    for (i in 3 downTo 0) {
        value = (value shl 8) or (raw[i] and 0xFF)
    }
    return value
}
|
||||
|
||||
/**
 * Reads the entire key-value section starting at the current stream position.
 *
 * Honours [skipKeys] (values for those keys are skipped via [skipValue]);
 * everything else is materialised with [parseValue].
 *
 * @param input Open stream positioned JUST AFTER the header.
 * @param kvCnt Number of key-value pairs (taken from the header).
 * @return Map with one [MetadataValue] for every key that is NOT skipped.
 */
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
    buildMap {
        repeat(kvCnt.toInt()) {
            val key = readString(input)
            val valueType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
            if (key in skipKeys) {
                skipValue(input, valueType)   // advance past the value, don't store it
            } else {
                put(key, parseValue(input, valueType))
            }
        }
    }
|
||||
|
||||
/**
 * Converts a flat [Map]<[String], [MetadataValue]> into the strongly-typed
 * [GgufMetadata] tree used by the rest of the app.
 *
 * Only the keys listed in the GGUF spec are copied into dedicated data classes;
 * everything else is preserved in `GgufMetadata.allMetadata`.
 *
 * Each optional section is built unconditionally and then discarded via the
 * file-private `takeUnless(T.() -> Boolean)` receiver overload when every
 * field came back null, so absent sections surface as null rather than as
 * empty objects.
 *
 * @param m Raw key/value map.
 * @param version GGUF file-format version (enum).
 * @param tensorCnt Number of tensors (from the header).
 * @param kvCnt Total metadata pair count (from the header).
 */
private fun buildStructured(
    m: Map<String, MetadataValue>,
    version: GgufMetadata.GgufVersion,
    tensorCnt: Long,
    kvCnt: Long
): GgufMetadata {
    // ---------- helpers: typed lookups into the raw map ----------
    // Each returns null when the key is absent or holds a different type.
    // (i32/f64 are not referenced below; kept for symmetry with the others.)
    fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
    fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
    fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
    fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
    fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
    fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
    fun String.strList(): List<String>? =
        (m[this] as? MetadataValue.ArrayVal)
            ?.elements
            ?.mapNotNull { (it as? MetadataValue.StringVal)?.value }

    // Architecture-specific keys are namespaced by this prefix
    // (e.g. "llama.context_length"); default to the llama architecture.
    val arch = "general.architecture".str() ?: ARCH_LLAMA

    // -------------- populate sections ----------------
    val basic = GgufMetadata.BasicInfo(
        uuid = "general.uuid".str(),
        name = "general.basename".str(),
        nameLabel = "general.name".str(),
        sizeLabel = "general.size_label".str()
    )

    val author = GgufMetadata.AuthorInfo(
        organization = "general.organization".str(),
        author = "general.author".str(),
        doi = "general.doi".str(),
        url = "general.url".str(),
        repoUrl = "general.repo_url".str(),
        license = "general.license".str(),
        licenseLink = "general.license.link".str()
    ).takeUnless {
        organization == null && author == null && doi == null &&
        url == null && repoUrl == null && license == null && licenseLink == null
    }

    val additional = GgufMetadata.AdditionalInfo(
        type = "general.type".str(),
        description = "general.description".str(),
        tags = "general.tags".strList(),
        languages = "general.languages".strList()
    ).takeUnless {
        type == null && description == null && tags == null && languages == null
    }

    val architectureInfo = GgufMetadata.ArchitectureInfo(
        architecture = arch,
        fileType = "general.file_type".u32(),
        vocabSize = "$arch.vocab_size".u32(),
        finetune = "general.finetune".str(),
        quantizationVersion = "general.quantization_version".u32()
    ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }

    // Base models are stored as an indexed family of keys:
    // "general.base_model.<i>.<field>", with the count in "general.base_model.count".
    val baseModels = buildList {
        val n = "general.base_model.count".u32() ?: 0
        for (i in 0 until n) {
            fun k(s: String) = "general.base_model.$i.$s"
            add(
                GgufMetadata.BaseModelInfo(
                    name = k("name").str(),
                    author = k("author").str(),
                    version = k("version").str(),
                    organization = k("organization").str(),
                    url = k("url").str(),
                    doi = k("doi").str(),
                    uuid = k("uuid").str(),
                    repoUrl = k("repo_url").str(),
                )
            )
        }
    }.takeIf { it.isNotEmpty() }

    val tokenizer = GgufMetadata.TokenizerInfo(
        model = "tokenizer.ggml.model".str(),
        bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
        eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
        unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
        paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
        addBosToken = "tokenizer.ggml.add_bos_token".bool(),
        addEosToken = "tokenizer.ggml.add_eos_token".bool(),
        chatTemplate = "tokenizer.chat_template".str()
    ).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
        unknownTokenId == null && paddingTokenId == null &&
        addBosToken == null && addEosToken == null && chatTemplate == null
    }

    val dimensions = GgufMetadata.DimensionsInfo(
        contextLength = "$arch.context_length".u32(),
        embeddingSize = "$arch.embedding_length".u32(),
        blockCount = "$arch.block_count".u32(),
        feedForwardSize = "$arch.feed_forward_length".u32()
    ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }

    val attention = GgufMetadata.AttentionInfo(
        headCount = "$arch.attention.head_count".u32(),
        headCountKv = "$arch.attention.head_count_kv".u32(),
        keyLength = "$arch.attention.key_length".u32(),
        valueLength = "$arch.attention.value_length".u32(),
        layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
        layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
    ).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
        layerNormEpsilon == null && layerNormRmsEpsilon == null
    }

    val rope = GgufMetadata.RopeInfo(
        frequencyBase = "$arch.rope.freq_base".f32(),
        dimensionCount = "$arch.rope.dimension_count".u32(),
        scalingType = "$arch.rope.scaling.type".str(),
        scalingFactor = "$arch.rope.scaling.factor".f32(),
        attnFactor = "$arch.rope.scaling.attn_factor".f32(),
        originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
        finetuned = "$arch.rope.scaling.finetuned".bool()
    ).takeUnless { frequencyBase == null && dimensionCount == null &&
        scalingType == null && scalingFactor == null && attnFactor == null &&
        originalContextLength == null && finetuned == null
    }

    val experts = GgufMetadata.ExpertsInfo(
        count = "$arch.expert_count".u32(),
        usedCount = "$arch.expert_used_count".u32()
    ).takeUnless { count == null && usedCount == null }

    // Assemble the final tree; sections that were all-null above are null here.
    return GgufMetadata(
        version = version,
        tensorCount = tensorCnt,
        kvCount = kvCnt,
        basic = basic,
        author = author,
        additional = additional,
        architecture = architectureInfo,
        baseModels = baseModels,
        tokenizer = tokenizer,
        dimensions = dimensions,
        attention = attention,
        rope = rope,
        experts = experts
    )
}
|
||||
|
||||
/**
 * Recursively parses a metadata value of the given type from the input stream.
 *
 * Fix: the previous implementation used single `input.read(bytes)` calls and
 * treated any short read as EOF, but `InputStream.read(byte[])` may legally
 * return fewer bytes than requested without being at end-of-stream. All
 * fixed-width reads now loop until the requested byte count is filled.
 *
 * @param input The input stream positioned at the start of the value.
 * @param type The metadata value type to parse.
 * @throws IOException on premature EOF or an invalid boolean encoding.
 */
private fun parseValue(input: InputStream, type: MetadataType): MetadataValue {
    // Read exactly n bytes, looping over short reads.
    fun readExactly(n: Int, what: String): ByteArray {
        val buf = ByteArray(n)
        var off = 0
        while (off < n) {
            val read = input.read(buf, off, n - off)
            if (read == -1) throw IOException("Unexpected EOF while reading $what value.")
            off += read
        }
        return buf
    }

    // Little-endian byte arrays -> integer bit patterns.
    fun le16(b: ByteArray): Int = ((b[1].toInt() and 0xFF) shl 8) or (b[0].toInt() and 0xFF)
    fun le32(b: ByteArray): Int {
        var bits = 0
        for (i in 3 downTo 0) bits = (bits shl 8) or (b[i].toInt() and 0xFF)
        return bits
    }
    fun le64(b: ByteArray): Long {
        var bits = 0L
        for (i in 7 downTo 0) bits = (bits shl 8) or (b[i].toLong() and 0xFFL)
        return bits
    }

    return when (type) {
        MetadataType.UINT8 -> MetadataValue.UInt8(readExactly(1, "uint8")[0].toUByte())
        MetadataType.INT8 -> MetadataValue.Int8(readExactly(1, "int8")[0])
        MetadataType.UINT16 -> MetadataValue.UInt16(le16(readExactly(2, "uint16")).toUShort())
        MetadataType.INT16 -> MetadataValue.Int16(le16(readExactly(2, "int16")).toShort())
        MetadataType.UINT32 -> MetadataValue.UInt32(le32(readExactly(4, "uint32")).toUInt())
        MetadataType.INT32 -> MetadataValue.Int32(le32(readExactly(4, "int32")))
        MetadataType.FLOAT32 -> MetadataValue.Float32(Float.fromBits(le32(readExactly(4, "float32"))))
        MetadataType.BOOL -> {
            // 1-byte boolean; the spec only allows 0 or 1.
            val b = readExactly(1, "boolean")[0].toInt()
            if (b != 0 && b != 1) {
                throw IOException("Invalid boolean value: $b (must be 0 or 1).")
            }
            MetadataValue.Bool(b != 0)
        }
        MetadataType.STRING -> MetadataValue.StringVal(readString(input))
        MetadataType.ARRAY -> {
            val elemType = MetadataType.fromCode(le32(readExactly(4, "array element type")))
            val count = readLittleLong(input).toInt()
            if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
                // Fast-forward without allocation; keep only a summary string.
                repeat(count) { skipValue(input, elemType) }
                MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
            } else {
                val list = ArrayList<MetadataValue>(count)
                repeat(count) { list += parseValue(input, elemType) }
                MetadataValue.ArrayVal(elemType, list)
            }
        }
        MetadataType.UINT64 -> MetadataValue.UInt64(le64(readExactly(8, "uint64")).toULong())
        MetadataType.INT64 -> MetadataValue.Int64(le64(readExactly(8, "int64")))
        MetadataType.FLOAT64 -> MetadataValue.Float64(Double.fromBits(le64(readExactly(8, "float64"))))
    }
}
|
||||
|
||||
|
||||
/**
 * Receiver-lambda variant of stdlib `takeUnless`: returns this value, or null
 * when the value is null or [check] evaluates to true on it.
 */
private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
    if (this != null && !this.check()) this else null
|
||||
|
||||
/** Advances the stream past one value of [type] without materialising it. */
private fun skipValue(input: InputStream, type: MetadataType) {
    // Fixed-width types map directly to a byte count; variable-width ones don't.
    val fixedSize = when (type) {
        MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> 1L
        MetadataType.UINT16, MetadataType.INT16 -> 2L
        MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> 4L
        MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> 8L
        MetadataType.STRING, MetadataType.ARRAY -> null
    }
    when {
        fixedSize != null -> input.skipFully(fixedSize)
        type == MetadataType.STRING -> input.skipFully(readLittleLong(input))
        else -> {
            // ARRAY: element type code, element count, then each element recursively.
            val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
            repeat(readLittleLong(input).toInt()) { skipValue(input, elemType) }
        }
    }
}
|
||||
|
||||
/**
 * Reads an 8-byte little-endian unsigned value and returns it as a signed Long.
 *
 * If bit 63 is set the result wraps negative (two's complement); in this
 * reader such values are lengths/counts and are not expected to be that large.
 *
 * @throws IOException on premature EOF.
 */
private fun readLittleLong(input: InputStream): Long {
    val raw = ByteArray(8)
    // Fill all 8 bytes, looping over short reads.
    var off = 0
    while (off < raw.size) {
        val n = input.read(raw, off, raw.size - off)
        if (n == -1) throw IOException("EOF after $off of ${raw.size} bytes")
        off += n
    }
    // Fold most-significant byte first (byte 7 down to byte 0).
    var value = 0L
    for (i in 7 downTo 0) {
        value = (value shl 8) or (raw[i].toLong() and 0xFFL)
    }
    return value
}
|
||||
|
||||
/** Reads a GGUF string: an 8-byte little-endian byte length followed by that many UTF-8 bytes. */
private fun readString(input: InputStream): String {
    // Length prefix is the number of BYTES, not characters.
    val len = readLittleLong(input)
    if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")

    val utf8 = ByteArray(len.toInt())
    if (utf8.isNotEmpty()) input.readFully(utf8)
    return String(utf8, Charsets.UTF_8)
}
|
||||
|
||||
/**
 * Converts the first four bytes of [bytes] (little-endian) into an Int.
 *
 * Fix: the previous version silently assumed a length of exactly 4 (crashing
 * with ArrayIndexOutOfBoundsException on shorter input); the assumption is now
 * validated with a descriptive error.
 *
 * @throws IllegalArgumentException if fewer than four bytes are supplied.
 */
private fun littleEndianBytesToInt(bytes: ByteArray): Int {
    require(bytes.size >= 4) { "Need at least 4 bytes, got ${bytes.size}" }
    var value = 0
    for (i in 3 downTo 0) {
        value = (value shl 8) or (bytes[i].toInt() and 0xFF)
    }
    return value
}
|
||||
|
||||
/**
 * Skips exactly [n] bytes, behaving identically on JDK 11 and Android's
 * desugared runtime. Falls back to read-and-discard when `skip()` makes no
 * progress (which some Android streams do).
 *
 * @param n Number of bytes to advance in the stream.
 * @throws IOException on premature EOF or a negative `skip()` result.
 */
private fun InputStream.skipFully(n: Long) {
    var left = n
    var scratch: ByteArray? = null   // throwaway buffer, allocated on demand
    while (left > 0) {
        val jumped = skip(left)
        if (jumped < 0) throw IOException("Skip returned negative value")
        if (jumped > 0) {
            left -= jumped           // normal fast path
            continue
        }
        // skip() returned 0: consume bytes manually instead.
        val buf = scratch ?: ByteArray(8192).also { scratch = it }
        val got = read(buf, 0, minOf(left, buf.size.toLong()).toInt())
        if (got == -1) throw IOException("EOF while skipping $n bytes")
        left -= got
    }
}
|
||||
|
||||
/**
 * Keeps reading until [len] bytes have been written into [buf], looping over
 * short reads (an `InputStream.read` call may return fewer bytes than asked
 * without being at end-of-stream).
 *
 * @param buf Destination buffer.
 * @param len Number of bytes to fill (defaults to `buf.size`).
 * @throws IOException on premature EOF.
 */
private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
    var filled = 0
    while (filled < len) {
        val got = read(buf, filled, len - filled)
        if (got == -1) throw IOException("EOF after $filled of $len bytes")
        filled += got
    }
}
|
||||
|
||||
/**
 * Reads EXACTLY [n] bytes or throws — never returns a partially-filled array.
 * Used for small fixed-length reads (e.g. 4-byte type codes).
 *
 * Fix: the previous implementation issued a single `read(buf)` call and threw
 * whenever it returned anything other than [n]; `InputStream.read(byte[])`
 * may legally return fewer bytes without being at end-of-stream, so we now
 * loop until the buffer is full.
 *
 * @throws IOException on premature EOF.
 */
private fun InputStream.readNBytesExact(n: Int): ByteArray {
    val buf = ByteArray(n)
    var off = 0
    while (off < n) {
        val got = read(buf, off, n - off)
        if (got == -1) throw IOException("Unexpected EOF")
        off += got
    }
    return buf
}
|
||||
}
|
||||
Loading…
Reference in New Issue