lib: refactor the GgufMetadataReader to take InputStream instead of absolute path as argument

This commit is contained in:
Han Yin 2025-08-06 13:05:34 -07:00
parent 2c9b1d37e0
commit 46e82c09f6
3 changed files with 48 additions and 52 deletions

View File

@ -208,9 +208,10 @@ class ModelRepositoryImpl @Inject constructor(
// Extract GGUF metadata if possible // Extract GGUF metadata if possible
val metadata = try { val metadata = try {
val filePath = modelFile.absolutePath Log.i(TAG, "Extracting GGUF Metadata from ${modelFile.absolutePath}")
Log.i(TAG, "Extracting GGUF Metadata from $filePath") modelFile.inputStream().buffered().use {
GgufMetadata.fromDomain(ggufMetadataReader.readStructuredMetadata(filePath)) GgufMetadata.fromDomain(ggufMetadataReader.readStructuredMetadata(it))
}
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Cannot extract GGUF metadata: ${e.message}", e) Log.e(TAG, "Cannot extract GGUF metadata: ${e.message}", e)
throw e throw e

View File

@ -4,6 +4,7 @@ import android.content.Context
import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl
import android.net.Uri import android.net.Uri
import java.io.IOException import java.io.IOException
import java.io.InputStream
/** /**
* Interface for reading GGUF metadata from model files. * Interface for reading GGUF metadata from model files.
@ -23,12 +24,12 @@ interface GgufMetadataReader {
/** /**
* Reads and parses GGUF metadata from the specified file path. * Reads and parses GGUF metadata from the specified file path.
* *
* @param path The absolute path to the GGUF file * @param input the [InputStream] obtained from a readable file or content
* @return Structured metadata extracted from the file * @return Structured metadata extracted from the file
* @throws IOException if file is damaged or cannot be read * @throws IOException if file is damaged or cannot be read
* @throws InvalidFileFormatException if file format is invalid * @throws InvalidFileFormatException if file format is invalid
*/ */
suspend fun readStructuredMetadata(path: String): GgufMetadata suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
companion object { companion object {
private val DEFAULT_SKIP_KEYS = setOf( private val DEFAULT_SKIP_KEYS = setOf(

View File

@ -5,7 +5,6 @@ import android.llama.cpp.gguf.GgufMetadata
import android.llama.cpp.gguf.GgufMetadataReader import android.llama.cpp.gguf.GgufMetadataReader
import android.llama.cpp.gguf.InvalidFileFormatException import android.llama.cpp.gguf.InvalidFileFormatException
import android.net.Uri import android.net.Uri
import java.io.File
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
@ -79,11 +78,11 @@ internal class GgufMetadataReaderImpl(
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
/** Reads the 4byte magic; throws if magic ≠ "GGUF". */ /** Reads the 4byte magic; throws if magic ≠ "GGUF". */
private fun ensureMagic(input: InputStream): Boolean { private fun ensureMagic(input: InputStream): Boolean =
val magic = ByteArray(4) ByteArray(4).let {
if (input.read(magic) != 4) throw IOException("Not a valid file!") if (input.read(it) != 4) throw IOException("Not a valid file!")
return magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF" it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
} }
/** /**
* Highlevel entry point: parses a `.gguf` file on disk and returns the fully * Highlevel entry point: parses a `.gguf` file on disk and returns the fully
@ -106,20 +105,18 @@ internal class GgufMetadataReaderImpl(
* @throws IOException if the file is not GGUF, the version is unsupported, * @throws IOException if the file is not GGUF, the version is unsupported,
* or the metadata block is truncated / corrupt. * or the metadata block is truncated / corrupt.
*/ */
override suspend fun readStructuredMetadata(path: String): GgufMetadata { override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
File(path).inputStream().buffered().use { input -> // ── 1. header ──────────────────────────────────────────────────────────
// ── 1. header ────────────────────────────────────────────────────────── // throws on mismatch
// throws on mismatch val version = ensureMagicAndVersion(input)
val version = ensureMagicAndVersion(input) val tensorCount = readLittleLong(input)
val tensorCount = readLittleLong(input) val kvCount = readLittleLong(input)
val kvCount = readLittleLong(input)
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ── // ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
val meta = readMetaMap(input, kvCount) // <String, MetadataValue> val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
// ── 3. build structured object ──────────────────────────────────────── // ── 3. build structured object ────────────────────────────────────────
return buildStructured(meta, version, tensorCount, kvCount) return buildStructured(meta, version, tensorCount, kvCount)
}
} }
/** Reads the 4byte magic + 4byte version; throws if magic ≠ "GGUF". */ /** Reads the 4byte magic + 4byte version; throws if magic ≠ "GGUF". */
@ -136,7 +133,7 @@ internal class GgufMetadataReaderImpl(
private fun readLEUInt32(input: InputStream): Int { private fun readLEUInt32(input: InputStream): Int {
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read() val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32") if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
return (b3 and 0xFF shl 24) or return (b3 and 0xFF shl 24) or
(b2 and 0xFF shl 16) or (b2 and 0xFF shl 16) or
(b1 and 0xFF shl 8) or (b1 and 0xFF shl 8) or
(b0 and 0xFF) (b0 and 0xFF)
@ -153,19 +150,18 @@ internal class GgufMetadataReaderImpl(
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking * The function honours [skipKeys] and [arraySummariseThreshold] by invoking
* [skipValue] or [parseValue] accordingly. * [skipValue] or [parseValue] accordingly.
*/ */
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> { private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
val map = mutableMapOf<String, MetadataValue>() mutableMapOf<String, MetadataValue>().apply {
repeat(kvCnt.toInt()) { repeat(kvCnt.toInt()) {
val key = readString(input) val key = readString(input)
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
if (key in skipKeys) { if (key in skipKeys) {
skipValue(input, valueT) skipValue(input, valueT)
} else { } else {
map[key] = parseValue(input, valueT) this[key] = parseValue(input, valueT)
} }
} }
return map }
}
/** /**
* Converts a flat [Map]<[String], [MetadataValue]> into the stronglytyped * Converts a flat [Map]<[String], [MetadataValue]> into the stronglytyped
@ -509,25 +505,25 @@ internal class GgufMetadataReaderImpl(
} }
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */ /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
private fun readString(input: InputStream): String { private fun readString(input: InputStream): String =
// Read 8-byte little-endian length (number of bytes in the string). // Read 8-byte little-endian length (number of bytes in the string).
val len = readLittleLong(input) readLittleLong(input).let { len ->
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len") if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
// Read the UTF-8 bytes of the given length. // Read the UTF-8 bytes of the given length.
val buf = ByteArray(len.toInt()) ByteArray(len.toInt()).let {
if (buf.isNotEmpty()) input.readFully(buf) if (it.isNotEmpty()) input.readFully(it)
return String(buf, Charsets.UTF_8) String(it, Charsets.UTF_8)
} }
}
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */ /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
private fun littleEndianBytesToInt(bytes: ByteArray): Int { private fun littleEndianBytesToInt(bytes: ByteArray): Int =
// Note: assumes bytes length is 4. // Note: assumes bytes length is 4.
return (bytes[3].toInt() and 0xFF shl 24) or (bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or (bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or (bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF) (bytes[0].toInt() and 0xFF)
}
/** /**
* Robust skip that works the same on JDK 11 and Androids desugared runtime. * Robust skip that works the same on JDK 11 and Androids desugared runtime.
@ -577,9 +573,7 @@ internal class GgufMetadataReaderImpl(
* *
* @throws IOException on premature EOF. * @throws IOException on premature EOF.
*/ */
private fun InputStream.readNBytesExact(n: Int): ByteArray { private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
val buf = ByteArray(n) if (read(it) != n) throw IOException("Unexpected EOF")
if (read(buf) != n) throw IOException("Unexpected EOF")
return buf
} }
} }