lib: refactor the GgufMetadataReader to take InputStream instead of absolute path as argument
This commit is contained in:
parent
2c9b1d37e0
commit
46e82c09f6
|
|
@ -208,9 +208,10 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
|
||||
// Extract GGUF metadata if possible
|
||||
val metadata = try {
|
||||
val filePath = modelFile.absolutePath
|
||||
Log.i(TAG, "Extracting GGUF Metadata from $filePath")
|
||||
GgufMetadata.fromDomain(ggufMetadataReader.readStructuredMetadata(filePath))
|
||||
Log.i(TAG, "Extracting GGUF Metadata from ${modelFile.absolutePath}")
|
||||
modelFile.inputStream().buffered().use {
|
||||
GgufMetadata.fromDomain(ggufMetadataReader.readStructuredMetadata(it))
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Cannot extract GGUF metadata: ${e.message}", e)
|
||||
throw e
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import android.content.Context
|
|||
import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl
|
||||
import android.net.Uri
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
|
||||
/**
|
||||
* Interface for reading GGUF metadata from model files.
|
||||
|
|
@ -23,12 +24,12 @@ interface GgufMetadataReader {
|
|||
/**
|
||||
* Reads and parses GGUF metadata from the specified file path.
|
||||
*
|
||||
* @param path The absolute path to the GGUF file
|
||||
* @param input the [InputStream] obtained from a readable file or content
|
||||
* @return Structured metadata extracted from the file
|
||||
* @throws IOException if file is damaged or cannot be read
|
||||
* @throws InvalidFileFormatException if file format is invalid
|
||||
*/
|
||||
suspend fun readStructuredMetadata(path: String): GgufMetadata
|
||||
suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
|
||||
|
||||
companion object {
|
||||
private val DEFAULT_SKIP_KEYS = setOf(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import android.llama.cpp.gguf.GgufMetadata
|
|||
import android.llama.cpp.gguf.GgufMetadataReader
|
||||
import android.llama.cpp.gguf.InvalidFileFormatException
|
||||
import android.net.Uri
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
|
||||
|
|
@ -79,11 +78,11 @@ internal class GgufMetadataReaderImpl(
|
|||
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
|
||||
|
||||
/** Reads the 4‑byte magic; throws if magic ≠ "GGUF". */
|
||||
private fun ensureMagic(input: InputStream): Boolean {
|
||||
val magic = ByteArray(4)
|
||||
if (input.read(magic) != 4) throw IOException("Not a valid file!")
|
||||
return magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
|
||||
}
|
||||
private fun ensureMagic(input: InputStream): Boolean =
|
||||
ByteArray(4).let {
|
||||
if (input.read(it) != 4) throw IOException("Not a valid file!")
|
||||
it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
|
||||
}
|
||||
|
||||
/**
|
||||
* High‑level entry point: parses a `.gguf` file on disk and returns the fully
|
||||
|
|
@ -106,20 +105,18 @@ internal class GgufMetadataReaderImpl(
|
|||
* @throws IOException if the file is not GGUF, the version is unsupported,
|
||||
* or the metadata block is truncated / corrupt.
|
||||
*/
|
||||
override suspend fun readStructuredMetadata(path: String): GgufMetadata {
|
||||
File(path).inputStream().buffered().use { input ->
|
||||
// ── 1. header ──────────────────────────────────────────────────────────
|
||||
// throws on mismatch
|
||||
val version = ensureMagicAndVersion(input)
|
||||
val tensorCount = readLittleLong(input)
|
||||
val kvCount = readLittleLong(input)
|
||||
override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
|
||||
// ── 1. header ──────────────────────────────────────────────────────────
|
||||
// throws on mismatch
|
||||
val version = ensureMagicAndVersion(input)
|
||||
val tensorCount = readLittleLong(input)
|
||||
val kvCount = readLittleLong(input)
|
||||
|
||||
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
|
||||
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
|
||||
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
|
||||
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
|
||||
|
||||
// ── 3. build structured object ────────────────────────────────────────
|
||||
return buildStructured(meta, version, tensorCount, kvCount)
|
||||
}
|
||||
// ── 3. build structured object ────────────────────────────────────────
|
||||
return buildStructured(meta, version, tensorCount, kvCount)
|
||||
}
|
||||
|
||||
/** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
|
||||
|
|
@ -136,7 +133,7 @@ internal class GgufMetadataReaderImpl(
|
|||
private fun readLEUInt32(input: InputStream): Int {
|
||||
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
|
||||
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
|
||||
return (b3 and 0xFF shl 24) or
|
||||
return (b3 and 0xFF shl 24) or
|
||||
(b2 and 0xFF shl 16) or
|
||||
(b1 and 0xFF shl 8) or
|
||||
(b0 and 0xFF)
|
||||
|
|
@ -153,19 +150,18 @@ internal class GgufMetadataReaderImpl(
|
|||
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
|
||||
* [skipValue] or [parseValue] accordingly.
|
||||
*/
|
||||
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> {
|
||||
val map = mutableMapOf<String, MetadataValue>()
|
||||
repeat(kvCnt.toInt()) {
|
||||
val key = readString(input)
|
||||
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
if (key in skipKeys) {
|
||||
skipValue(input, valueT)
|
||||
} else {
|
||||
map[key] = parseValue(input, valueT)
|
||||
}
|
||||
}
|
||||
return map
|
||||
}
|
||||
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
|
||||
mutableMapOf<String, MetadataValue>().apply {
|
||||
repeat(kvCnt.toInt()) {
|
||||
val key = readString(input)
|
||||
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
if (key in skipKeys) {
|
||||
skipValue(input, valueT)
|
||||
} else {
|
||||
this[key] = parseValue(input, valueT)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed
|
||||
|
|
@ -509,25 +505,25 @@ internal class GgufMetadataReaderImpl(
|
|||
}
|
||||
|
||||
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
|
||||
private fun readString(input: InputStream): String {
|
||||
private fun readString(input: InputStream): String =
|
||||
// Read 8-byte little-endian length (number of bytes in the string).
|
||||
val len = readLittleLong(input)
|
||||
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
|
||||
readLittleLong(input).let { len ->
|
||||
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
|
||||
|
||||
// Read the UTF-8 bytes of the given length.
|
||||
val buf = ByteArray(len.toInt())
|
||||
if (buf.isNotEmpty()) input.readFully(buf)
|
||||
return String(buf, Charsets.UTF_8)
|
||||
}
|
||||
// Read the UTF-8 bytes of the given length.
|
||||
ByteArray(len.toInt()).let {
|
||||
if (it.isNotEmpty()) input.readFully(it)
|
||||
String(it, Charsets.UTF_8)
|
||||
}
|
||||
}
|
||||
|
||||
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
|
||||
private fun littleEndianBytesToInt(bytes: ByteArray): Int {
|
||||
private fun littleEndianBytesToInt(bytes: ByteArray): Int =
|
||||
// Note: assumes bytes length is 4.
|
||||
return (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
}
|
||||
|
||||
/**
|
||||
* Robust skip that works the same on JDK 11 and Android’s desugared runtime.
|
||||
|
|
@ -577,9 +573,7 @@ internal class GgufMetadataReaderImpl(
|
|||
*
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.readNBytesExact(n: Int): ByteArray {
|
||||
val buf = ByteArray(n)
|
||||
if (read(buf) != n) throw IOException("Unexpected EOF")
|
||||
return buf
|
||||
private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
|
||||
if (read(it) != n) throw IOException("Unexpected EOF")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue