From 381994234c706f815ba763a0d553ee25d4b8823f Mon Sep 17 00:00:00 2001 From: Han Yin Date: Mon, 4 Aug 2025 08:18:38 -0700 Subject: [PATCH] lib: read & validate the magic number from the picked source file before executing the import --- .../llama/data/repo/ModelRepository.kt | 8 +++++- .../viewmodel/ModelsManagementViewModel.kt | 11 ++++++-- .../llama/cpp/gguf/GgufMetadataReader.kt | 18 +++++++++++-- .../internal/gguf/GgufMetadataReaderImpl.kt | 27 +++++++++++++++---- 4 files changed, 54 insertions(+), 10 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt index 0a204950c4..0750e9670f 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt @@ -2,6 +2,7 @@ package com.example.llama.data.repo import android.content.Context import android.llama.cpp.gguf.GgufMetadataReader +import android.llama.cpp.gguf.InvalidFileFormatException import android.net.Uri import android.os.StatFs import android.util.Log @@ -171,9 +172,15 @@ class ModelRepositoryImpl @Inject constructor( throw IllegalStateException("Another import is already in progress!") } + // Check file info val fileInfo = localFileDataSource.getFileInfo(uri) val fileSize = size ?: fileInfo?.size ?: throw FileNotFoundException("File size N/A") val fileName = name ?: fileInfo?.name ?: throw FileNotFoundException("File name N/A") + if (!ggufMetadataReader.ensureSourceFileFormat(context, uri)) { + throw InvalidFileFormatException() + } + + // Check for enough storage if (!hasEnoughSpaceForImport(fileSize)) { throw InsufficientStorageException( "Not enough storage space! " + @@ -182,7 +189,6 @@ class ModelRepositoryImpl @Inject constructor( ) } val modelFile = File(modelsDir, fileName) - importJob = coroutineContext[Job] currentModelFile = modelFile diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt index 19014b3148..79ed1ea0ec 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt @@ -6,6 +6,7 @@ import android.content.Context import android.content.Context.RECEIVER_EXPORTED import android.content.Intent import android.content.IntentFilter +import android.llama.cpp.gguf.InvalidFileFormatException import android.net.Uri import android.util.Log import androidx.lifecycle.ViewModel @@ -208,13 +209,18 @@ class ModelsManagementViewModel @Inject constructor( _managementState.value = Importation.Importing(progress, fileName, fileSize) } _managementState.value = Importation.Success(model) + } catch (_: InvalidFileFormatException) { + _managementState.value = Importation.Error( + message = "Not a valid GGUF model!", + ) } catch (e: InsufficientStorageException) { _managementState.value = Importation.Error( - message = e.message ?: "Insufficient storage space to import $uri", + message = e.message ?: "Insufficient storage space to import $fileName", ) } catch (e: Exception) { + Log.e(TAG, "Unknown exception importing $fileName", e) _managementState.value = Importation.Error( - message = e.message ?: "Unknown error importing $uri", + message = e.message ?: "Unknown error importing $fileName", ) } } @@ -376,6 +382,7 @@ sealed class ModelManagementState { data class Confirming(val uri: Uri, val fileName: String, val fileSize: Long) : Importation() data class Importing(val progress: Float = 0f, val fileName: String, val fileSize: Long, val isCancelling: Boolean = false) : Importation() data class Success(val model: ModelInfo) : Importation() + // TODO-han.yin: Add an optional explanation URL for more info! data class Error(val message: String) : Importation() } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadataReader.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadataReader.kt index bfc590fff5..995d02774a 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadataReader.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/gguf/GgufMetadataReader.kt @@ -1,6 +1,8 @@ package android.llama.cpp.gguf +import android.content.Context import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl +import android.net.Uri import java.io.IOException /** @@ -8,13 +10,23 @@ import java.io.IOException * Use `GgufMetadataReader.create()` to get an instance. */ interface GgufMetadataReader { + /** + * Reads the magic number from the specified file path. + * + * @param context Context for obtaining ContentResolver + * @param uri Uri to the GGUF file provided by ContentProvider + * @return true if file is valid GGUF, otherwise false + * @throws InvalidFileFormatException if file format is invalid + */ + suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean + /** * Reads and parses GGUF metadata from the specified file path. * * @param path The absolute path to the GGUF file * @return Structured metadata extracted from the file - * @throws IOException if file cannot be read - * @throws IllegalArgumentException if file format is invalid + * @throws IOException if file is damaged or cannot be read + * @throws InvalidFileFormatException if file format is invalid */ suspend fun readStructuredMetadata(path: String): GgufMetadata @@ -50,3 +62,5 @@ interface GgufMetadataReader { ) } } + +class InvalidFileFormatException : IOException() diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/gguf/GgufMetadataReaderImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/gguf/GgufMetadataReaderImpl.kt index 44944d3223..f20083593f 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/gguf/GgufMetadataReaderImpl.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/gguf/GgufMetadataReaderImpl.kt @@ -1,7 +1,10 @@ package android.llama.cpp.internal.gguf +import android.content.Context import android.llama.cpp.gguf.GgufMetadata import android.llama.cpp.gguf.GgufMetadataReader +import android.llama.cpp.gguf.InvalidFileFormatException +import android.net.Uri import java.io.File import java.io.IOException import java.io.InputStream @@ -25,7 +28,7 @@ internal class GgufMetadataReaderImpl( UINT32(4), INT32(5), FLOAT32(6), BOOL(7), STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12); companion object { - private val codeMap = values().associateBy(MetadataType::code) + private val codeMap = entries.associateBy(MetadataType::code) fun fromCode(code: Int): MetadataType = codeMap[code] ?: throw IOException("Unknown metadata value type code: $code") } @@ -65,6 +68,23 @@ internal class GgufMetadataReaderImpl( is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() } } + /** + * Reads the magic number from the specified file path. + * + * @param context Context for obtaining ContentResolver + * @param uri Uri to the GGUF file provided by ContentProvider + * @return true if file is valid GGUF, otherwise false + */ + override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean = + context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true + + /** Reads the 4‑byte magic; throws if magic ≠ "GGUF". */ + private fun ensureMagic(input: InputStream): Boolean { + val magic = ByteArray(4) + if (input.read(magic) != 4) throw IOException("Not a valid file!") + return magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF" + } + /** * High‑level entry point: parses a `.gguf` file on disk and returns the fully * populated [GgufMetadata] tree. @@ -104,10 +124,7 @@ internal class GgufMetadataReaderImpl( /** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */ private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion { - val magic = ByteArray(4) - if (input.read(magic) != 4) throw IOException("File too short (no magic)") - if (!magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46))) // "GGUF" - throw IOException("Not a GGUF file (bad magic)") + if (!ensureMagic(input)) throw InvalidFileFormatException() return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input)) }