lib: refactor the GgufMetadataReader to take InputStream instead of absolute path as argument

This commit is contained in:
Han Yin 2025-08-06 13:05:34 -07:00
parent 2c9b1d37e0
commit 46e82c09f6
3 changed files with 48 additions and 52 deletions

View File

@ -208,9 +208,10 @@ class ModelRepositoryImpl @Inject constructor(
// Extract GGUF metadata if possible // Extract GGUF metadata if possible
val metadata = try { val metadata = try {
val filePath = modelFile.absolutePath Log.i(TAG, "Extracting GGUF Metadata from ${modelFile.absolutePath}")
Log.i(TAG, "Extracting GGUF Metadata from $filePath") modelFile.inputStream().buffered().use {
GgufMetadata.fromDomain(ggufMetadataReader.readStructuredMetadata(filePath)) GgufMetadata.fromDomain(ggufMetadataReader.readStructuredMetadata(it))
}
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Cannot extract GGUF metadata: ${e.message}", e) Log.e(TAG, "Cannot extract GGUF metadata: ${e.message}", e)
throw e throw e

View File

@ -4,6 +4,7 @@ import android.content.Context
import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl
import android.net.Uri import android.net.Uri
import java.io.IOException import java.io.IOException
import java.io.InputStream
/** /**
* Interface for reading GGUF metadata from model files. * Interface for reading GGUF metadata from model files.
@ -23,12 +24,12 @@ interface GgufMetadataReader {
/** /**
* Reads and parses GGUF metadata from the specified file path. * Reads and parses GGUF metadata from the specified file path.
* *
* @param path The absolute path to the GGUF file * @param input the [InputStream] obtained from a readable file or content
* @return Structured metadata extracted from the file * @return Structured metadata extracted from the file
* @throws IOException if file is damaged or cannot be read * @throws IOException if file is damaged or cannot be read
* @throws InvalidFileFormatException if file format is invalid * @throws InvalidFileFormatException if file format is invalid
*/ */
suspend fun readStructuredMetadata(path: String): GgufMetadata suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
companion object { companion object {
private val DEFAULT_SKIP_KEYS = setOf( private val DEFAULT_SKIP_KEYS = setOf(

View File

@ -5,7 +5,6 @@ import android.llama.cpp.gguf.GgufMetadata
import android.llama.cpp.gguf.GgufMetadataReader import android.llama.cpp.gguf.GgufMetadataReader
import android.llama.cpp.gguf.InvalidFileFormatException import android.llama.cpp.gguf.InvalidFileFormatException
import android.net.Uri import android.net.Uri
import java.io.File
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
@ -79,10 +78,10 @@ internal class GgufMetadataReaderImpl(
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
/** Reads the 4-byte magic; throws if magic ≠ "GGUF". */ /** Reads the 4-byte magic; throws if magic ≠ "GGUF". */
private fun ensureMagic(input: InputStream): Boolean { private fun ensureMagic(input: InputStream): Boolean =
val magic = ByteArray(4) ByteArray(4).let {
if (input.read(magic) != 4) throw IOException("Not a valid file!") if (input.read(it) != 4) throw IOException("Not a valid file!")
return magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF" it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
} }
/** /**
@ -106,8 +105,7 @@ internal class GgufMetadataReaderImpl(
* @throws IOException if the file is not GGUF, the version is unsupported, * @throws IOException if the file is not GGUF, the version is unsupported,
* or the metadata block is truncated / corrupt. * or the metadata block is truncated / corrupt.
*/ */
override suspend fun readStructuredMetadata(path: String): GgufMetadata { override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
File(path).inputStream().buffered().use { input ->
// ── 1. header ────────────────────────────────────────────────────────── // ── 1. header ──────────────────────────────────────────────────────────
// throws on mismatch // throws on mismatch
val version = ensureMagicAndVersion(input) val version = ensureMagicAndVersion(input)
@ -120,7 +118,6 @@ internal class GgufMetadataReaderImpl(
// ── 3. build structured object ──────────────────────────────────────── // ── 3. build structured object ────────────────────────────────────────
return buildStructured(meta, version, tensorCount, kvCount) return buildStructured(meta, version, tensorCount, kvCount)
} }
}
/** Reads the 4-byte magic + 4-byte version; throws if magic ≠ "GGUF". */ /** Reads the 4-byte magic + 4-byte version; throws if magic ≠ "GGUF". */
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion { private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
@ -153,18 +150,17 @@ internal class GgufMetadataReaderImpl(
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking * The function honours [skipKeys] and [arraySummariseThreshold] by invoking
* [skipValue] or [parseValue] accordingly. * [skipValue] or [parseValue] accordingly.
*/ */
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> { private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
val map = mutableMapOf<String, MetadataValue>() mutableMapOf<String, MetadataValue>().apply {
repeat(kvCnt.toInt()) { repeat(kvCnt.toInt()) {
val key = readString(input) val key = readString(input)
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
if (key in skipKeys) { if (key in skipKeys) {
skipValue(input, valueT) skipValue(input, valueT)
} else { } else {
map[key] = parseValue(input, valueT) this[key] = parseValue(input, valueT)
} }
} }
return map
} }
/** /**
@ -509,25 +505,25 @@ internal class GgufMetadataReaderImpl(
} }
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */ /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
private fun readString(input: InputStream): String { private fun readString(input: InputStream): String =
// Read 8-byte little-endian length (number of bytes in the string). // Read 8-byte little-endian length (number of bytes in the string).
val len = readLittleLong(input) readLittleLong(input).let { len ->
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len") if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
// Read the UTF-8 bytes of the given length. // Read the UTF-8 bytes of the given length.
val buf = ByteArray(len.toInt()) ByteArray(len.toInt()).let {
if (buf.isNotEmpty()) input.readFully(buf) if (it.isNotEmpty()) input.readFully(it)
return String(buf, Charsets.UTF_8) String(it, Charsets.UTF_8)
}
} }
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */ /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
private fun littleEndianBytesToInt(bytes: ByteArray): Int { private fun littleEndianBytesToInt(bytes: ByteArray): Int =
// Note: assumes bytes length is 4. // Note: assumes bytes length is 4.
return (bytes[3].toInt() and 0xFF shl 24) or (bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or (bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or (bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF) (bytes[0].toInt() and 0xFF)
}
/** /**
* Robust skip that works the same on JDK 11 and Android's desugared runtime.
@ -577,9 +573,7 @@ internal class GgufMetadataReaderImpl(
* *
* @throws IOException on premature EOF. * @throws IOException on premature EOF.
*/ */
private fun InputStream.readNBytesExact(n: Int): ByteArray { private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
val buf = ByteArray(n) if (read(it) != n) throw IOException("Unexpected EOF")
if (read(buf) != n) throw IOException("Unexpected EOF")
return buf
} }
} }