gguf: add GGUF metadata data holder and its corresponding extractor implementation
This commit is contained in:
parent
a9466c0370
commit
67499727ef
|
|
@ -4,6 +4,8 @@ import androidx.room.Entity
|
|||
import androidx.room.PrimaryKey
|
||||
import com.example.llama.revamp.data.model.ModelInfo
|
||||
|
||||
// TODO-han.yin: Add GgufMetaData
|
||||
|
||||
@Entity(tableName = "models")
|
||||
data class ModelEntity(
|
||||
@PrimaryKey
|
||||
|
|
@ -11,22 +13,14 @@ data class ModelEntity(
|
|||
val name: String,
|
||||
val path: String,
|
||||
val sizeInBytes: Long,
|
||||
val parameters: String?,
|
||||
val quantization: String?,
|
||||
val type: String?,
|
||||
val contextLength: Int?,
|
||||
val lastUsed: Long?,
|
||||
val dateAdded: Long
|
||||
val dateAdded: Long,
|
||||
val lastUsed: Long?
|
||||
) {
|
||||
fun toModelInfo() = ModelInfo(
|
||||
id = id,
|
||||
name = name,
|
||||
path = path,
|
||||
sizeInBytes = sizeInBytes,
|
||||
parameters = parameters,
|
||||
quantization = quantization,
|
||||
type = type,
|
||||
contextLength = contextLength,
|
||||
lastUsed = lastUsed
|
||||
lastUsed = lastUsed,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,10 +10,6 @@ data class ModelInfo(
|
|||
val name: String,
|
||||
val path: String,
|
||||
val sizeInBytes: Long,
|
||||
val parameters: String?,
|
||||
val quantization: String?,
|
||||
val type: String?,
|
||||
val contextLength: Int?,
|
||||
val lastUsed: Long? = null
|
||||
) {
|
||||
val formattedSize: String
|
||||
|
|
|
|||
|
|
@ -8,11 +8,9 @@ import com.example.llama.revamp.data.local.ModelDao
|
|||
import com.example.llama.revamp.data.local.ModelEntity
|
||||
import com.example.llama.revamp.data.model.ModelInfo
|
||||
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
|
||||
import com.example.llama.revamp.util.GgufMetadataReader
|
||||
import com.example.llama.revamp.util.copyWithBuffer
|
||||
import com.example.llama.revamp.util.copyWithChannels
|
||||
import com.example.llama.revamp.util.extractModelTypeFromFilename
|
||||
import com.example.llama.revamp.util.extractParametersFromFilename
|
||||
import com.example.llama.revamp.util.extractQuantizationFromFilename
|
||||
import com.example.llama.revamp.util.formatSize
|
||||
import com.example.llama.revamp.util.getFileNameFromUri
|
||||
import com.example.llama.revamp.util.getFileSizeFromUri
|
||||
|
|
@ -188,10 +186,15 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
// Extract model parameters from filename
|
||||
val modelType = extractModelTypeFromFilename(fileName)
|
||||
val parameters = extractParametersFromFilename(fileName)
|
||||
val quantization = extractQuantizationFromFilename(fileName)
|
||||
// Extract GGUF metadata if possible
|
||||
val metadata = try {
|
||||
val filePath = modelFile.absolutePath
|
||||
Log.i(TAG, "Extracting GGUF Metadata from $filePath")
|
||||
GgufMetadataReader().readStructuredMetadata(filePath)
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to extract GGUF metadata: ${e.message}", e)
|
||||
null
|
||||
}
|
||||
|
||||
// Create model entity and save via DAO
|
||||
ModelEntity(
|
||||
|
|
@ -199,12 +202,9 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
name = fileName.substringBeforeLast('.'),
|
||||
path = modelFile.absolutePath,
|
||||
sizeInBytes = modelFile.length(),
|
||||
parameters = parameters,
|
||||
quantization = quantization,
|
||||
type = modelType,
|
||||
contextLength = DEFAULT_CONTEXT_SIZE,
|
||||
lastUsed = null,
|
||||
dateAdded = System.currentTimeMillis()
|
||||
// TODO-han.yin: add metadata here
|
||||
dateAdded = System.currentTimeMillis(),
|
||||
lastUsed = null
|
||||
).let {
|
||||
modelDao.insertModel(it)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,25 @@ import java.nio.ByteBuffer
|
|||
import java.nio.channels.Channels
|
||||
import java.nio.channels.ReadableByteChannel
|
||||
import java.nio.channels.WritableByteChannel
|
||||
import java.util.Locale
|
||||
|
||||
/**
|
||||
* Convert bytes into human readable sizes
|
||||
*/
|
||||
fun formatSize(sizeInBytes: Long) = when {
|
||||
sizeInBytes >= 1_000_000_000 -> {
|
||||
val sizeInGb = sizeInBytes / 1_000_000_000.0
|
||||
String.format(Locale.getDefault(), "%.2f GB", sizeInGb)
|
||||
}
|
||||
sizeInBytes >= 1_000_000 -> {
|
||||
val sizeInMb = sizeInBytes / 1_000_000.0
|
||||
String.format(Locale.getDefault(), "%.2f MB", sizeInMb)
|
||||
}
|
||||
else -> {
|
||||
val sizeInKb = sizeInBytes / 1_000.0
|
||||
String.format(Locale.getDefault(), "%.2f KB", sizeInKb)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the file name from a content URI
|
||||
|
|
|
|||
|
|
@ -0,0 +1,232 @@
|
|||
package com.example.llama.revamp.util
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
|
||||
/**
|
||||
* Structured metadata of GGUF
|
||||
*/
|
||||
data class GgufMetadata(
|
||||
// Basic file info
|
||||
val version: GgufVersion,
|
||||
val tensorCount: Long,
|
||||
val kvCount: Long,
|
||||
|
||||
// General info
|
||||
val basic: BasicInfo,
|
||||
val author: AuthorInfo? = null,
|
||||
val additional: AdditionalInfo? = null,
|
||||
val architecture: ArchitectureInfo? = null,
|
||||
val baseModels: List<BaseModelInfo>? = null,
|
||||
val tokenizer: TokenizerInfo? = null,
|
||||
|
||||
// Derivative info
|
||||
val dimensions: DimensionsInfo? = null,
|
||||
val attention: AttentionInfo? = null,
|
||||
val rope: RopeInfo? = null,
|
||||
val experts: ExpertsInfo? = null
|
||||
) {
|
||||
/** Human-readable full model name + size */
|
||||
val fullModelName: String?
|
||||
get() = when {
|
||||
basic.nameLabel != null -> basic.nameLabel
|
||||
basic.name != null && basic.sizeLabel != null -> "${basic.name}-${basic.sizeLabel}"
|
||||
basic.name != null -> basic.name
|
||||
else -> null
|
||||
}
|
||||
|
||||
/** Human‑readable model name (spaces). */
|
||||
val primaryName: String?
|
||||
get() = basic.nameLabel
|
||||
?: baseModels?.firstNotNullOfOrNull { it.name }
|
||||
?: basic.name
|
||||
|
||||
/** CLI‑friendly slug (hyphens). */
|
||||
val primaryBasename: String?
|
||||
get() = basic.name
|
||||
?: baseModels?.firstNotNullOfOrNull { it.name?.replace(' ', '-') }
|
||||
|
||||
/** URL pointing to model homepage/repo. */
|
||||
val primaryUrl: String?
|
||||
get() = author?.url
|
||||
?: baseModels?.firstNotNullOfOrNull { it.url }
|
||||
|
||||
val primaryRepoUrl: String?
|
||||
get() = author?.repoUrl
|
||||
?: baseModels?.firstNotNullOfOrNull { it.repoUrl }
|
||||
|
||||
/** Organisation string. */
|
||||
val primaryOrganization: String?
|
||||
get() = author?.organization
|
||||
?: baseModels?.firstNotNullOfOrNull { it.organization }
|
||||
|
||||
/** Author string. */
|
||||
val primaryAuthor: String?
|
||||
get() = author?.author
|
||||
?: baseModels?.firstNotNullOfOrNull { it.author }
|
||||
|
||||
/** Context length with unit, e.g. “32768 tokens”. */
|
||||
val formattedContextLength: String?
|
||||
get() = dimensions?.contextLength?.let { "$it tokens" }
|
||||
|
||||
enum class GgufVersion(val code: Int, val label: String) {
|
||||
/** First public draft; little‑endian only, no alignment key. */
|
||||
LEGACY_V1(1, "Legacy v1"),
|
||||
|
||||
/** Added split‑file support and some extra metadata keys. */
|
||||
EXTENDED_V2(2, "Extended v2"),
|
||||
|
||||
/** Current spec: endian‑aware, mandatory alignment, fully validated. */
|
||||
VALIDATED_V3(3, "Validated v3");
|
||||
|
||||
companion object {
|
||||
fun fromCode(code: Int): GgufVersion =
|
||||
entries.firstOrNull { it.code == code }
|
||||
?: throw IOException("Unknown GGUF version code $code")
|
||||
}
|
||||
|
||||
override fun toString(): String = "$label (code=$code)"
|
||||
}
|
||||
|
||||
data class BasicInfo(
|
||||
val uuid: String? = null,
|
||||
val name: String? = null,
|
||||
val nameLabel: String? = null,
|
||||
val sizeLabel: String? = null, // Size label like "7B"
|
||||
)
|
||||
|
||||
data class AuthorInfo(
|
||||
val organization: String? = null,
|
||||
val author: String? = null,
|
||||
val doi: String? = null,
|
||||
val url: String? = null,
|
||||
val repoUrl: String? = null,
|
||||
val license: String? = null,
|
||||
val licenseLink: String? = null,
|
||||
)
|
||||
|
||||
data class AdditionalInfo(
|
||||
val type: String? = null,
|
||||
val description: String? = null,
|
||||
val tags: List<String>? = null,
|
||||
val languages: List<String>? = null,
|
||||
)
|
||||
|
||||
data class ArchitectureInfo(
|
||||
val architecture: String? = null,
|
||||
val fileType: Int? = null,
|
||||
val vocabSize: Int? = null,
|
||||
val finetune: String? = null,
|
||||
val quantizationVersion: Int? = null,
|
||||
)
|
||||
|
||||
data class BaseModelInfo(
|
||||
val name: String? = null,
|
||||
val author: String? = null,
|
||||
val version: String? = null,
|
||||
val organization: String? = null,
|
||||
val url: String? = null,
|
||||
val doi: String? = null,
|
||||
val uuid: String? = null,
|
||||
val repoUrl: String? = null,
|
||||
)
|
||||
|
||||
data class TokenizerInfo(
|
||||
val model: String? = null,
|
||||
val bosTokenId: Int? = null,
|
||||
val eosTokenId: Int? = null,
|
||||
val unknownTokenId: Int? = null,
|
||||
val paddingTokenId: Int? = null,
|
||||
val addBosToken: Boolean? = null,
|
||||
val addEosToken: Boolean? = null,
|
||||
val chatTemplate: String? = null,
|
||||
)
|
||||
|
||||
data class DimensionsInfo(
|
||||
val contextLength: Int? = null,
|
||||
val embeddingSize: Int? = null,
|
||||
val blockCount: Int? = null,
|
||||
val feedForwardSize: Int? = null,
|
||||
)
|
||||
|
||||
data class AttentionInfo(
|
||||
val headCount: Int? = null,
|
||||
val headCountKv: Int? = null,
|
||||
val keyLength: Int? = null,
|
||||
val valueLength: Int? = null,
|
||||
val layerNormEpsilon: Float? = null,
|
||||
val layerNormRmsEpsilon: Float? = null,
|
||||
)
|
||||
|
||||
data class RopeInfo(
|
||||
val frequencyBase: Float? = null,
|
||||
val dimensionCount: Int? = null,
|
||||
val scalingType: String? = null,
|
||||
val scalingFactor: Float? = null,
|
||||
val attnFactor: Float? = null,
|
||||
val originalContextLength: Int? = null,
|
||||
val finetuned: Boolean? = null,
|
||||
)
|
||||
|
||||
data class ExpertsInfo(
|
||||
val count: Int? = null,
|
||||
val usedCount: Int? = null,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
|
||||
* The `label` matches what llama‑cli prints.
|
||||
*/
|
||||
enum class FileType(val code: Int, val label: String) {
|
||||
ALL_F32(0, "all F32"),
|
||||
MOSTLY_F16(1, "F16"),
|
||||
MOSTLY_Q4_0(2, "Q4_0"),
|
||||
MOSTLY_Q4_1(3, "Q4_1"),
|
||||
// 4 removed
|
||||
MOSTLY_Q8_0(7, "Q8_0"),
|
||||
MOSTLY_Q5_0(8, "Q5_0"),
|
||||
MOSTLY_Q5_1(9, "Q5_1"),
|
||||
|
||||
/* K‑quants ------------------------------------------------------------ */
|
||||
MOSTLY_Q2_K (10, "Q2_K - Medium"),
|
||||
MOSTLY_Q3_K_S (11, "Q3_K - Small"),
|
||||
MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
|
||||
MOSTLY_Q3_K_L (13, "Q3_K - Large"),
|
||||
MOSTLY_Q4_K_S (14, "Q4_K - Small"),
|
||||
MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
|
||||
MOSTLY_Q5_K_S (16, "Q5_K - Small"),
|
||||
MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
|
||||
MOSTLY_Q6_K (18, "Q6_K"),
|
||||
|
||||
/* IQ quants ----------------------------------------------------------- */
|
||||
MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
|
||||
MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
|
||||
MOSTLY_Q2_K_S (21, "Q2_K - Small"),
|
||||
MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
|
||||
MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
|
||||
MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
|
||||
MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
|
||||
MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
|
||||
MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
|
||||
MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
|
||||
MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
|
||||
MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
|
||||
MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
|
||||
|
||||
/* BF16 & Ternary ------------------------------------------------------ */
|
||||
MOSTLY_BF16 (32, "BF16"),
|
||||
MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
|
||||
MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
|
||||
|
||||
/* Special flag -------------------------------------------------------- */
|
||||
GUESSED(1024, "(guessed)"),
|
||||
|
||||
UNKNOWN(-1, "unknown");
|
||||
|
||||
companion object {
|
||||
private val map = entries.associateBy(FileType::code)
|
||||
|
||||
fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,574 @@
|
|||
package com.example.llama.revamp.util
|
||||
|
||||
import java.io.File
|
||||
import java.io.InputStream
|
||||
import java.io.IOException
|
||||
|
||||
private val DEFAULT_SKIP_KEYS = setOf(
|
||||
"tokenizer.chat_template",
|
||||
"tokenizer.ggml.scores",
|
||||
"tokenizer.ggml.tokens",
|
||||
"tokenizer.ggml.token_type"
|
||||
)
|
||||
|
||||
/**
|
||||
* Utility class to read GGUF model files and extract metadata key-value pairs.
|
||||
* This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
|
||||
*/
|
||||
class GgufMetadataReader(
|
||||
/** Keys whose value should be skipped entirely (not kept in the resulting map). */
|
||||
private val skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
|
||||
/** If ≥0, arrays longer than this get summarised instead of materialised. -1 ⇒ never summarise. */
|
||||
private val arraySummariseThreshold: Int = 1_000
|
||||
) {
|
||||
companion object {
|
||||
private const val ARCH_LLAMA = "llama"
|
||||
}
|
||||
|
||||
/** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
|
||||
enum class MetadataType(val code: Int) {
|
||||
UINT8(0), INT8(1), UINT16(2), INT16(3),
|
||||
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
|
||||
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
|
||||
companion object {
|
||||
private val codeMap = values().associateBy(MetadataType::code)
|
||||
fun fromCode(code: Int): MetadataType = codeMap[code]
|
||||
?: throw IOException("Unknown metadata value type code: $code")
|
||||
}
|
||||
}
|
||||
|
||||
/** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
|
||||
sealed class MetadataValue {
|
||||
data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
|
||||
data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
|
||||
data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
|
||||
data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
|
||||
data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
|
||||
data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
|
||||
data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
|
||||
data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
|
||||
data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
|
||||
data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
|
||||
data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
|
||||
data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
|
||||
data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
|
||||
}
|
||||
|
||||
/* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
|
||||
private fun MetadataValue.toPrimitive(): Any = when (this) {
|
||||
is MetadataValue.UInt8 -> value
|
||||
is MetadataValue.Int8 -> value
|
||||
is MetadataValue.UInt16 -> value
|
||||
is MetadataValue.Int16 -> value
|
||||
is MetadataValue.UInt32 -> value
|
||||
is MetadataValue.Int32 -> value
|
||||
is MetadataValue.Float32 -> value
|
||||
is MetadataValue.Bool -> value
|
||||
is MetadataValue.StringVal -> value
|
||||
is MetadataValue.UInt64 -> value
|
||||
is MetadataValue.Int64 -> value
|
||||
is MetadataValue.Float64 -> value
|
||||
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
|
||||
}
|
||||
|
||||
/**
|
||||
* High‑level entry point: parses a `.gguf` file on disk and returns the fully
|
||||
* populated [GgufMetadata] tree.
|
||||
*
|
||||
* Steps performed internally:
|
||||
* 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version).
|
||||
* 2. Streams through the key‑value section, skipping large blobs if the key
|
||||
* appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
|
||||
* 3. Converts the resulting raw map into strongly‑typed sub‑structures
|
||||
* (basic info, tokenizer, rope, etc.).
|
||||
*
|
||||
* The method is STREAMING‑ONLY: tensors are never mapped or loaded into
|
||||
* memory, so even multi‑GB model files can be processed in < 50 ms.
|
||||
*
|
||||
* @param path Absolute or relative filesystem path to a `.gguf` file.
|
||||
* @return A [GgufMetadata] instance containing all recognised metadata plus
|
||||
* an `allMetadata` map with any keys that were not given a dedicated
|
||||
* field.
|
||||
* @throws IOException if the file is not GGUF, the version is unsupported,
|
||||
* or the metadata block is truncated / corrupt.
|
||||
*/
|
||||
fun readStructuredMetadata(path: String): GgufMetadata {
|
||||
File(path).inputStream().buffered().use { input ->
|
||||
// ── 1. header ──────────────────────────────────────────────────────────
|
||||
// throws on mismatch
|
||||
val version = ensureMagicAndVersion(input)
|
||||
val tensorCount = readLittleLong(input)
|
||||
val kvCount = readLittleLong(input)
|
||||
|
||||
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
|
||||
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
|
||||
|
||||
// ── 3. build structured object ────────────────────────────────────────
|
||||
return buildStructured(meta, version, tensorCount, kvCount)
|
||||
}
|
||||
}
|
||||
|
||||
/** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
|
||||
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
|
||||
val magic = ByteArray(4)
|
||||
if (input.read(magic) != 4) throw IOException("File too short (no magic)")
|
||||
if (!magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46))) // "GGUF"
|
||||
throw IOException("Not a GGUF file (bad magic)")
|
||||
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
|
||||
}
|
||||
|
||||
/**
|
||||
* Read an unsigned 32‑bit little‑endian integer.
|
||||
*
|
||||
* @throws IOException if fewer than four bytes are available.
|
||||
*/
|
||||
private fun readLEUInt32(input: InputStream): Int {
|
||||
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
|
||||
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
|
||||
return (b3 and 0xFF shl 24) or
|
||||
(b2 and 0xFF shl 16) or
|
||||
(b1 and 0xFF shl 8) or
|
||||
(b0 and 0xFF)
|
||||
}
|
||||
|
||||
/**
|
||||
* Low‑level helper that reads the entire “key-value” section from the current
|
||||
* stream position.
|
||||
*
|
||||
* @param input Open stream positioned JUST AFTER the header.
|
||||
* @param kvCnt Number of key‑value pairs (taken from the header).
|
||||
* @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
|
||||
*
|
||||
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
|
||||
* [skipValue] or [parseValue] accordingly.
|
||||
*/
|
||||
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> {
|
||||
val map = mutableMapOf<String, MetadataValue>()
|
||||
repeat(kvCnt.toInt()) {
|
||||
val key = readString(input)
|
||||
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
if (key in skipKeys) {
|
||||
skipValue(input, valueT)
|
||||
} else {
|
||||
map[key] = parseValue(input, valueT)
|
||||
}
|
||||
}
|
||||
return map
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed
|
||||
* [GgufMetadata] tree used by the rest of the app.
|
||||
*
|
||||
* Only the keys listed in the spec are copied into dedicated data classes;
|
||||
* everything else is preserved in `GgufMetadata.allMetadata`.
|
||||
*
|
||||
* @param m Raw key/value map.
|
||||
* @param version GGUF file‑format version (enum).
|
||||
* @param tensorCnt Number of tensors (from the header).
|
||||
* @param kvCnt Total metadata pair count (from the header).
|
||||
*/
|
||||
private fun buildStructured(
|
||||
m: Map<String, MetadataValue>,
|
||||
version: GgufMetadata.GgufVersion,
|
||||
tensorCnt: Long,
|
||||
kvCnt: Long
|
||||
): GgufMetadata {
|
||||
// ---------- helpers ----------
|
||||
fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
|
||||
fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
|
||||
fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
|
||||
fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
|
||||
fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
|
||||
fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
|
||||
fun String.strList(): List<String>? =
|
||||
(m[this] as? MetadataValue.ArrayVal)
|
||||
?.elements
|
||||
?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
|
||||
|
||||
val arch = "general.architecture".str() ?: ARCH_LLAMA
|
||||
|
||||
// -------------- populate sections ----------------
|
||||
val basic = GgufMetadata.BasicInfo(
|
||||
uuid = "general.uuid".str(),
|
||||
name = "general.basename".str(),
|
||||
nameLabel = "general.name".str(),
|
||||
sizeLabel = "general.size_label".str()
|
||||
)
|
||||
|
||||
val author = GgufMetadata.AuthorInfo(
|
||||
organization = "general.organization".str(),
|
||||
author = "general.author".str(),
|
||||
doi = "general.doi".str(),
|
||||
url = "general.url".str(),
|
||||
repoUrl = "general.repo_url".str(),
|
||||
license = "general.license".str(),
|
||||
licenseLink = "general.license.link".str()
|
||||
).takeUnless {
|
||||
organization == null && author == null && doi == null &&
|
||||
url == null && repoUrl == null && license == null && licenseLink == null
|
||||
}
|
||||
|
||||
val additional = GgufMetadata.AdditionalInfo(
|
||||
type = "general.type".str(),
|
||||
description = "general.description".str(),
|
||||
tags = "general.tags".strList(),
|
||||
languages = "general.languages".strList()
|
||||
).takeUnless {
|
||||
type == null && description == null && tags == null && languages == null
|
||||
}
|
||||
|
||||
val architectureInfo = GgufMetadata.ArchitectureInfo(
|
||||
architecture = arch,
|
||||
fileType = "general.file_type".u32(),
|
||||
vocabSize = "$arch.vocab_size".u32(),
|
||||
finetune = "general.finetune".str(),
|
||||
quantizationVersion = "general.quantization_version".u32()
|
||||
).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
|
||||
|
||||
val baseModels = buildList {
|
||||
val n = "general.base_model.count".u32() ?: 0
|
||||
for (i in 0 until n) {
|
||||
fun k(s: String) = "general.base_model.$i.$s"
|
||||
add(
|
||||
GgufMetadata.BaseModelInfo(
|
||||
name = k("name").str(),
|
||||
author = k("author").str(),
|
||||
version = k("version").str(),
|
||||
organization = k("organization").str(),
|
||||
url = k("url").str(),
|
||||
doi = k("doi").str(),
|
||||
uuid = k("uuid").str(),
|
||||
repoUrl = k("repo_url").str(),
|
||||
)
|
||||
)
|
||||
}
|
||||
}.takeIf { it.isNotEmpty() }
|
||||
|
||||
val tokenizer = GgufMetadata.TokenizerInfo(
|
||||
model = "tokenizer.ggml.model".str(),
|
||||
bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
|
||||
eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
|
||||
unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
|
||||
paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
|
||||
addBosToken = "tokenizer.ggml.add_bos_token".bool(),
|
||||
addEosToken = "tokenizer.ggml.add_eos_token".bool(),
|
||||
chatTemplate = "tokenizer.chat_template".str()
|
||||
).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
|
||||
unknownTokenId == null && paddingTokenId == null &&
|
||||
addBosToken == null && addEosToken == null && chatTemplate == null
|
||||
}
|
||||
|
||||
val dimensions = GgufMetadata.DimensionsInfo(
|
||||
contextLength = "$arch.context_length".u32(),
|
||||
embeddingSize = "$arch.embedding_length".u32(),
|
||||
blockCount = "$arch.block_count".u32(),
|
||||
feedForwardSize = "$arch.feed_forward_length".u32()
|
||||
).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
|
||||
|
||||
val attention = GgufMetadata.AttentionInfo(
|
||||
headCount = "$arch.attention.head_count".u32(),
|
||||
headCountKv = "$arch.attention.head_count_kv".u32(),
|
||||
keyLength = "$arch.attention.key_length".u32(),
|
||||
valueLength = "$arch.attention.value_length".u32(),
|
||||
layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
|
||||
layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
|
||||
).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
|
||||
layerNormEpsilon == null && layerNormRmsEpsilon == null
|
||||
}
|
||||
|
||||
val rope = GgufMetadata.RopeInfo(
|
||||
frequencyBase = "$arch.rope.freq_base".f32(),
|
||||
dimensionCount = "$arch.rope.dimension_count".u32(),
|
||||
scalingType = "$arch.rope.scaling.type".str(),
|
||||
scalingFactor = "$arch.rope.scaling.factor".f32(),
|
||||
attnFactor = "$arch.rope.scaling.attn_factor".f32(),
|
||||
originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
|
||||
finetuned = "$arch.rope.scaling.finetuned".bool()
|
||||
).takeUnless { frequencyBase == null && dimensionCount == null &&
|
||||
scalingType == null && scalingFactor == null && attnFactor == null &&
|
||||
originalContextLength == null && finetuned == null
|
||||
}
|
||||
|
||||
val experts = GgufMetadata.ExpertsInfo(
|
||||
count = "$arch.expert_count".u32(),
|
||||
usedCount = "$arch.expert_used_count".u32()
|
||||
).takeUnless { count == null && usedCount == null }
|
||||
|
||||
return GgufMetadata(
|
||||
version = version,
|
||||
tensorCount = tensorCnt,
|
||||
kvCount = kvCnt,
|
||||
basic = basic,
|
||||
author = author,
|
||||
additional = additional,
|
||||
architecture = architectureInfo,
|
||||
baseModels = baseModels,
|
||||
tokenizer = tokenizer,
|
||||
dimensions = dimensions,
|
||||
attention = attention,
|
||||
rope = rope,
|
||||
experts = experts
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively parses a metadata value of the given type from the input stream.
|
||||
* @param input The input stream positioned at the start of the value.
|
||||
* @param type The metadata value type to parse.
|
||||
*/
|
||||
private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
|
||||
MetadataType.UINT8 -> {
|
||||
// 1-byte unsigned integer
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
|
||||
MetadataValue.UInt8(byteVal.toUByte())
|
||||
}
|
||||
MetadataType.INT8 -> {
|
||||
// 1-byte signed integer
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
|
||||
MetadataValue.Int8(byteVal.toByte())
|
||||
}
|
||||
MetadataType.UINT16 -> {
|
||||
// 2-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(2)
|
||||
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
|
||||
// Combine two bytes (little-endian) into an unsigned 16-bit value
|
||||
val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.UInt16(u16.toUShort())
|
||||
}
|
||||
MetadataType.INT16 -> {
|
||||
// 2-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(2)
|
||||
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
|
||||
// Combine to 16-bit and interpret as signed
|
||||
val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.Int16(i16.toShort())
|
||||
}
|
||||
MetadataType.UINT32 -> {
|
||||
// 4-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
|
||||
// Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
|
||||
val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
MetadataValue.UInt32(u32.toUInt())
|
||||
}
|
||||
MetadataType.INT32 -> {
|
||||
// 4-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
|
||||
// Combine four bytes into a 32-bit signed int
|
||||
val i32 = (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.Int32(i32)
|
||||
}
|
||||
MetadataType.FLOAT32 -> {
|
||||
// 4-byte IEEE 754 float (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
|
||||
// Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
|
||||
val bits = (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
val floatVal = Float.fromBits(bits)
|
||||
MetadataValue.Float32(floatVal)
|
||||
}
|
||||
MetadataType.BOOL -> {
|
||||
// 1-byte boolean (0 = false, 1 = true)
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
|
||||
if (byteVal != 0 && byteVal != 1) {
|
||||
throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
|
||||
}
|
||||
MetadataValue.Bool(byteVal != 0)
|
||||
}
|
||||
MetadataType.STRING -> {
|
||||
// UTF-8 string (length-prefixed with 8-byte length)
|
||||
val str = readString(input)
|
||||
MetadataValue.StringVal(str)
|
||||
}
|
||||
MetadataType.ARRAY -> {
|
||||
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
val len = readLittleLong(input)
|
||||
val count = len.toInt()
|
||||
|
||||
if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
|
||||
// fast‑forward without allocation
|
||||
repeat(count) { skipValue(input, elemType) }
|
||||
MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
|
||||
} else {
|
||||
val list = ArrayList<MetadataValue>(count)
|
||||
repeat(count) { list += parseValue(input, elemType) }
|
||||
MetadataValue.ArrayVal(elemType, list)
|
||||
}
|
||||
}
|
||||
MetadataType.UINT64 -> {
|
||||
// 8-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
|
||||
// Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
|
||||
val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
|
||||
(bytes[6].toULong() and 0xFFuL shl 48) or
|
||||
(bytes[5].toULong() and 0xFFuL shl 40) or
|
||||
(bytes[4].toULong() and 0xFFuL shl 32) or
|
||||
(bytes[3].toULong() and 0xFFuL shl 24) or
|
||||
(bytes[2].toULong() and 0xFFuL shl 16) or
|
||||
(bytes[1].toULong() and 0xFFuL shl 8) or
|
||||
(bytes[0].toULong() and 0xFFuL)
|
||||
MetadataValue.UInt64(u64)
|
||||
}
|
||||
MetadataType.INT64 -> {
|
||||
// 8-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
|
||||
// Combine 8 bytes into a signed 64-bit value (Long)
|
||||
val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
MetadataValue.Int64(i64)
|
||||
}
|
||||
MetadataType.FLOAT64 -> {
|
||||
// 8-byte IEEE 754 double (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
|
||||
// Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
|
||||
val bits = (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
val doubleVal = Double.fromBits(bits)
|
||||
MetadataValue.Float64(doubleVal)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
|
||||
this?.takeIf { !it.check() }
|
||||
|
||||
/** Helper: Skip a value in the stream without storing it (still maintains pointer). */
|
||||
private fun skipValue(input: InputStream, type: MetadataType) {
|
||||
when (type) {
|
||||
MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
|
||||
MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
|
||||
MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
|
||||
MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
|
||||
MetadataType.STRING -> {
|
||||
val len = readLittleLong(input); input.skipFully(len)
|
||||
}
|
||||
MetadataType.ARRAY -> {
|
||||
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
val len = readLittleLong(input)
|
||||
repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
|
||||
private fun readLittleLong(input: InputStream): Long {
|
||||
val bytes = ByteArray(8)
|
||||
input.readFully(bytes)
|
||||
|
||||
// Combine 8 bytes into a 64-bit value (Little Endian).
|
||||
// Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
|
||||
// In our context (lengths/counts), such extremely large values are not expected.
|
||||
return (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
}
|
||||
|
||||
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
|
||||
private fun readString(input: InputStream): String {
|
||||
// Read 8-byte little-endian length (number of bytes in the string).
|
||||
val len = readLittleLong(input)
|
||||
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
|
||||
|
||||
// Read the UTF-8 bytes of the given length.
|
||||
val buf = ByteArray(len.toInt())
|
||||
if (buf.isNotEmpty()) input.readFully(buf)
|
||||
return String(buf, Charsets.UTF_8)
|
||||
}
|
||||
|
||||
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
|
||||
private fun littleEndianBytesToInt(bytes: ByteArray): Int {
|
||||
// Note: assumes bytes length is 4.
|
||||
return (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
}
|
||||
|
||||
/**
|
||||
* Robust skip that works the same on JDK 11 and Android’s desugared runtime.
|
||||
*
|
||||
* @param n Number of bytes to advance in the stream.
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.skipFully(n: Long) {
|
||||
var remaining = n
|
||||
val scratch = ByteArray(8192) // read‑and‑toss buffer
|
||||
while (remaining > 0) {
|
||||
val skipped = skip(remaining)
|
||||
when {
|
||||
skipped > 0 -> remaining -= skipped // normal fast path
|
||||
skipped == 0L -> {
|
||||
// fallback: read and discard
|
||||
val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
|
||||
if (read == -1) throw IOException("EOF while skipping $n bytes")
|
||||
remaining -= read
|
||||
}
|
||||
else -> throw IOException("Skip returned negative value")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extension that keeps reading until the requested number of bytes are filled.
|
||||
* Falls back to `read()` when `skip()` returns 0, which happens on some Android
|
||||
* streams.
|
||||
*
|
||||
* @param buf Destination buffer.
|
||||
* @param len Number of bytes to fill (defaults to `buf.size`).
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
|
||||
var off = 0
|
||||
while (off < len) {
|
||||
val n = read(buf, off, len - off)
|
||||
if (n == -1) throw IOException("EOF after $off of $len bytes")
|
||||
off += n
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read EXACTLY `n` bytes or throw – never returns a partially‑filled array.
|
||||
* This is used for small fixed‑length reads (e.g. 4‑byte type codes).
|
||||
*
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.readNBytesExact(n: Int): ByteArray {
|
||||
val buf = ByteArray(n)
|
||||
if (read(buf) != n) throw IOException("Unexpected EOF")
|
||||
return buf
|
||||
}
|
||||
}
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
package com.example.llama.revamp.util
|
||||
|
||||
import java.util.Locale
|
||||
|
||||
|
||||
/**
|
||||
* Convert bytes into human readable sizes
|
||||
*/
|
||||
fun formatSize(sizeInBytes: Long) = when {
|
||||
sizeInBytes >= 1_000_000_000 -> {
|
||||
val sizeInGb = sizeInBytes / 1_000_000_000.0
|
||||
String.format(Locale.getDefault(), "%.2f GB", sizeInGb)
|
||||
}
|
||||
sizeInBytes >= 1_000_000 -> {
|
||||
val sizeInMb = sizeInBytes / 1_000_000.0
|
||||
String.format(Locale.getDefault(), "%.2f MB", sizeInMb)
|
||||
}
|
||||
else -> {
|
||||
val sizeInKb = sizeInBytes / 1_000.0
|
||||
String.format(Locale.getDefault(), "%.2f KB", sizeInKb)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract parameters by looking for patterns like 7B, 13B, etc.
|
||||
*/
|
||||
fun extractParametersFromFilename(filename: String): String? =
|
||||
Regex("([0-9]+(\\.[0-9]+)?)[bB]").find(filename)?.value?.uppercase()
|
||||
|
||||
/**
|
||||
* Try to extract quantization by looking for patterns like Q4_0, Q5_K_M, etc.
|
||||
*/
|
||||
fun extractQuantizationFromFilename(filename: String) =
|
||||
listOf(
|
||||
Regex("[qQ][0-9]_[0-9]"),
|
||||
Regex("[qQ][0-9]_[kK]_[mM]"),
|
||||
Regex("[qQ][0-9]_[kK]"),
|
||||
Regex("[qQ][0-9][fF](16|32)")
|
||||
).firstNotNullOfOrNull {
|
||||
it.find(filename)?.value?.uppercase()
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract model type (Llama, Mistral, etc.)
|
||||
*
|
||||
* TODO-han.yin: Replace with GGUF parsing, also to be moved into the util object
|
||||
*/
|
||||
fun extractModelTypeFromFilename(filename: String): String? {
|
||||
val lowerFilename = filename.lowercase()
|
||||
return listOf("llama", "mistral", "phi", "qwen", "falcon", "mpt")
|
||||
.firstNotNullOfOrNull { type ->
|
||||
if (lowerFilename.contains(type)) {
|
||||
type.replaceFirstChar {
|
||||
if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString()
|
||||
}
|
||||
} else { null }
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
package com.example.llama.revamp.util
|
||||
|
||||
import java.util.Locale
|
||||
|
||||
|
||||
@Deprecated("Use GgufMetadataReader instead!")
|
||||
class NaiveMetadataExtractor private constructor() {
|
||||
/**
|
||||
* Try to extract parameters by looking for patterns like 7B, 13B, etc.
|
||||
*/
|
||||
fun extractParametersFromFilename(filename: String): String? =
|
||||
Regex("([0-9]+(\\.[0-9]+)?)[bB]").find(filename)?.value?.uppercase()
|
||||
|
||||
/**
|
||||
* Try to extract quantization by looking for patterns like Q4_0, Q5_K_M, etc.
|
||||
*/
|
||||
fun extractQuantizationFromFilename(filename: String) =
|
||||
listOf(
|
||||
Regex("[qQ][0-9]_[0-9]"),
|
||||
Regex("[qQ][0-9]_[kK]_[mM]"),
|
||||
Regex("[qQ][0-9]_[kK]"),
|
||||
Regex("[qQ][0-9][fF](16|32)")
|
||||
).firstNotNullOfOrNull {
|
||||
it.find(filename)?.value?.uppercase()
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract model type (Llama, Mistral, etc.)
|
||||
*/
|
||||
fun extractModelTypeFromFilename(filename: String): String? =
|
||||
filename.lowercase().let { lowerFilename ->
|
||||
listOf("llama", "mistral", "phi", "qwen", "falcon", "mpt")
|
||||
.firstNotNullOfOrNull { type ->
|
||||
if (lowerFilename.contains(type)) {
|
||||
type.replaceFirstChar {
|
||||
if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString()
|
||||
}
|
||||
} else { null }
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue