lib: refactor the GgufMetadataReader to take InputStream instead of absolute path as argument

This commit is contained in:
Han Yin 2025-08-06 13:05:34 -07:00
parent 2c9b1d37e0
commit 46e82c09f6
3 changed files with 48 additions and 52 deletions

View File

@ -208,9 +208,10 @@ class ModelRepositoryImpl @Inject constructor(
// Extract GGUF metadata if possible
val metadata = try {
val filePath = modelFile.absolutePath
Log.i(TAG, "Extracting GGUF Metadata from $filePath")
GgufMetadata.fromDomain(ggufMetadataReader.readStructuredMetadata(filePath))
Log.i(TAG, "Extracting GGUF Metadata from ${modelFile.absolutePath}")
modelFile.inputStream().buffered().use {
GgufMetadata.fromDomain(ggufMetadataReader.readStructuredMetadata(it))
}
} catch (e: Exception) {
Log.e(TAG, "Cannot extract GGUF metadata: ${e.message}", e)
throw e

View File

@ -4,6 +4,7 @@ import android.content.Context
import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl
import android.net.Uri
import java.io.IOException
import java.io.InputStream
/**
* Interface for reading GGUF metadata from model files.
@ -23,12 +24,12 @@ interface GgufMetadataReader {
/**
* Reads and parses GGUF metadata from the given input stream.
*
* @param path The absolute path to the GGUF file
* @param input the [InputStream] obtained from a readable file or content
* @return Structured metadata extracted from the file
* @throws IOException if file is damaged or cannot be read
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun readStructuredMetadata(path: String): GgufMetadata
suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
companion object {
private val DEFAULT_SKIP_KEYS = setOf(

View File

@ -5,7 +5,6 @@ import android.llama.cpp.gguf.GgufMetadata
import android.llama.cpp.gguf.GgufMetadataReader
import android.llama.cpp.gguf.InvalidFileFormatException
import android.net.Uri
import java.io.File
import java.io.IOException
import java.io.InputStream
@ -79,11 +78,11 @@ internal class GgufMetadataReaderImpl(
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
/**
 * Reads the 4-byte magic from [input] and reports whether it equals "GGUF".
 *
 * @param input stream positioned at the first byte of the candidate file
 * @return `true` if the first four bytes are "GGUF", `false` for any other four bytes
 * @throws IOException if the stream ends before four bytes could be read
 */
private fun ensureMagic(input: InputStream): Boolean {
    val magic = ByteArray(4)
    var filled = 0
    // A single read() may return fewer bytes than requested without being at EOF,
    // so keep reading until the buffer is full or the stream really ends.
    while (filled < magic.size) {
        val got = input.read(magic, filled, magic.size - filled)
        if (got < 0) throw IOException("Not a valid file!")
        filled += got
    }
    return magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
}
/**
* High-level entry point: parses GGUF data from an input stream and returns the fully
@ -106,20 +105,18 @@ internal class GgufMetadataReaderImpl(
* @throws IOException if the file is not GGUF, the version is unsupported,
* or the metadata block is truncated / corrupt.
*/
// The stream is consumed from its current position and is NOT closed here;
// the caller that opened it is responsible for closing it.
// NOTE(review): the reads below are blocking — confirm callers invoke this off the main thread.
override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
    // ── 1. header ──────────────────────────────────────────────────────────
    // throws on mismatch
    val version = ensureMagicAndVersion(input)
    val tensorCount = readLittleLong(input)
    val kvCount = readLittleLong(input)

    // ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
    val meta = readMetaMap(input, kvCount) // <String, MetadataValue>

    // ── 3. build structured object ────────────────────────────────────────
    return buildStructured(meta, version, tensorCount, kvCount)
}
/** Reads the 4-byte magic + 4-byte version; throws if magic ≠ "GGUF". */
@ -136,7 +133,7 @@ internal class GgufMetadataReaderImpl(
private fun readLEUInt32(input: InputStream): Int {
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
return (b3 and 0xFF shl 24) or
return (b3 and 0xFF shl 24) or
(b2 and 0xFF shl 16) or
(b1 and 0xFF shl 8) or
(b0 and 0xFF)
@ -153,19 +150,18 @@ internal class GgufMetadataReaderImpl(
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
* [skipValue] or [parseValue] accordingly.
*/
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> {
    // The count comes straight from the file header. Reject values that toInt()
    // would silently truncate, or that repeat() would treat as "no entries".
    if (kvCnt < 0 || kvCnt > Int.MAX_VALUE) throw IOException("Invalid metadata KV count: $kvCnt")

    val entries = mutableMapOf<String, MetadataValue>()
    repeat(kvCnt.toInt()) {
        // Each entry is: key string, 4-byte little-endian type code, then the value payload.
        val key = readString(input)
        val valueType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
        if (key in skipKeys) {
            // Skipped keys must still be consumed so the stream stays aligned for the next entry.
            skipValue(input, valueType)
        } else {
            entries[key] = parseValue(input, valueType)
        }
    }
    return entries
}
/**
* Converts a flat [Map]<[String], [MetadataValue]> into the strongly-typed
@ -509,25 +505,25 @@ internal class GgufMetadataReaderImpl(
}
/**
 * Helper: reads a GGUF string from the stream (8-byte little-endian length followed by UTF-8 bytes).
 *
 * @param input stream positioned at the string's length prefix
 * @return the decoded string; empty when the length prefix is zero
 * @throws IOException if the length is negative or exceeds [Int.MAX_VALUE]
 */
private fun readString(input: InputStream): String {
    // Read 8-byte little-endian length (number of bytes in the string).
    val len = readLittleLong(input)
    if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")

    // Read the UTF-8 bytes of the given length; a zero-length string needs no read at all.
    val buf = ByteArray(len.toInt())
    if (buf.isNotEmpty()) input.readFully(buf)
    return String(buf, Charsets.UTF_8)
}
/**
 * Helper: converts a 4-byte little-endian byte array to a 32-bit integer.
 *
 * @param bytes array whose first four elements are the value, least-significant byte first
 * @return the assembled [Int]; the bit pattern is preserved, so values with the top bit set come back negative
 */
private fun littleEndianBytesToInt(bytes: ByteArray): Int =
    // Note: assumes bytes length is 4.
    // Infix `and` / `shl` share one precedence level and associate left-to-right,
    // so each term is (byte and 0xFF) shl n; the parentheses below only spell that out.
    ((bytes[3].toInt() and 0xFF) shl 24) or
        ((bytes[2].toInt() and 0xFF) shl 16) or
        ((bytes[1].toInt() and 0xFF) shl 8) or
        (bytes[0].toInt() and 0xFF)
/**
* Robust skip that works the same on JDK 11 and Android's desugared runtime.
@ -577,9 +573,7 @@ internal class GgufMetadataReaderImpl(
*
* @throws IOException on premature EOF.
*/
private fun InputStream.readNBytesExact(n: Int): ByteArray {
    val buf = ByteArray(n)
    var filled = 0
    // read() is allowed to deliver fewer bytes than requested before EOF,
    // so loop until the buffer is full; only a negative return means the stream ended.
    while (filled < n) {
        val got = read(buf, filled, n - filled)
        if (got < 0) throw IOException("Unexpected EOF")
        filled += got
    }
    return buf
}
}