lib: read & validate the magic number from the picked source file before executing the import

This commit is contained in:
Han Yin 2025-08-04 08:18:38 -07:00
parent 7968216235
commit 381994234c
4 changed files with 54 additions and 10 deletions

View File

@ -2,6 +2,7 @@ package com.example.llama.data.repo
import android.content.Context
import android.llama.cpp.gguf.GgufMetadataReader
import android.llama.cpp.gguf.InvalidFileFormatException
import android.net.Uri
import android.os.StatFs
import android.util.Log
@ -171,9 +172,15 @@ class ModelRepositoryImpl @Inject constructor(
throw IllegalStateException("Another import is already in progress!")
}
// Check file info
val fileInfo = localFileDataSource.getFileInfo(uri)
val fileSize = size ?: fileInfo?.size ?: throw FileNotFoundException("File size N/A")
val fileName = name ?: fileInfo?.name ?: throw FileNotFoundException("File name N/A")
if (!ggufMetadataReader.ensureSourceFileFormat(context, uri)) {
throw InvalidFileFormatException()
}
// Check for enough storage
if (!hasEnoughSpaceForImport(fileSize)) {
throw InsufficientStorageException(
"Not enough storage space! " +
@ -182,7 +189,6 @@ class ModelRepositoryImpl @Inject constructor(
)
}
val modelFile = File(modelsDir, fileName)
importJob = coroutineContext[Job]
currentModelFile = modelFile

View File

@ -6,6 +6,7 @@ import android.content.Context
import android.content.Context.RECEIVER_EXPORTED
import android.content.Intent
import android.content.IntentFilter
import android.llama.cpp.gguf.InvalidFileFormatException
import android.net.Uri
import android.util.Log
import androidx.lifecycle.ViewModel
@ -208,13 +209,18 @@ class ModelsManagementViewModel @Inject constructor(
_managementState.value = Importation.Importing(progress, fileName, fileSize)
}
_managementState.value = Importation.Success(model)
} catch (_: InvalidFileFormatException) {
_managementState.value = Importation.Error(
message = "Not a valid GGUF model!",
)
} catch (e: InsufficientStorageException) {
_managementState.value = Importation.Error(
message = e.message ?: "Insufficient storage space to import $uri",
message = e.message ?: "Insufficient storage space to import $fileName",
)
} catch (e: Exception) {
Log.e(TAG, "Unknown exception importing $fileName", e)
_managementState.value = Importation.Error(
message = e.message ?: "Unknown error importing $uri",
message = e.message ?: "Unknown error importing $fileName",
)
}
}
@ -376,6 +382,7 @@ sealed class ModelManagementState {
data class Confirming(val uri: Uri, val fileName: String, val fileSize: Long) : Importation()
data class Importing(val progress: Float = 0f, val fileName: String, val fileSize: Long, val isCancelling: Boolean = false) : Importation()
data class Success(val model: ModelInfo) : Importation()
// TODO-han.yin: Add an optional explanation URL for more info!
data class Error(val message: String) : Importation()
}

View File

@ -1,6 +1,8 @@
package android.llama.cpp.gguf
import android.content.Context
import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl
import android.net.Uri
import java.io.IOException
/**
@ -8,13 +10,23 @@ import java.io.IOException
* Use `GgufMetadataReader.create()` to get an instance.
*/
interface GgufMetadataReader {
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining ContentResolver
* @param uri Uri to the GGUF file provided by ContentProvider
* @return true if file is valid GGUF, otherwise false
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean
/**
* Reads and parses GGUF metadata from the specified file path.
*
* @param path The absolute path to the GGUF file
* @return Structured metadata extracted from the file
* @throws IOException if file cannot be read
* @throws IllegalArgumentException if file format is invalid
* @throws IOException if file is damaged or cannot be read
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun readStructuredMetadata(path: String): GgufMetadata
@ -50,3 +62,5 @@ interface GgufMetadataReader {
)
}
}
class InvalidFileFormatException : IOException()

View File

@ -1,7 +1,10 @@
package android.llama.cpp.internal.gguf
import android.content.Context
import android.llama.cpp.gguf.GgufMetadata
import android.llama.cpp.gguf.GgufMetadataReader
import android.llama.cpp.gguf.InvalidFileFormatException
import android.net.Uri
import java.io.File
import java.io.IOException
import java.io.InputStream
@ -25,7 +28,7 @@ internal class GgufMetadataReaderImpl(
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
companion object {
private val codeMap = values().associateBy(MetadataType::code)
private val codeMap = entries.associateBy(MetadataType::code)
fun fromCode(code: Int): MetadataType = codeMap[code]
?: throw IOException("Unknown metadata value type code: $code")
}
@ -65,6 +68,23 @@ internal class GgufMetadataReaderImpl(
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
}
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining ContentResolver
* @param uri Uri to the GGUF file provided by ContentProvider
* @return true if file is valid GGUF, otherwise false
*/
override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
/** Reads the 4byte magic; throws if magic ≠ "GGUF". */
private fun ensureMagic(input: InputStream): Boolean {
val magic = ByteArray(4)
if (input.read(magic) != 4) throw IOException("Not a valid file!")
return magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
}
/**
* Highlevel entry point: parses a `.gguf` file on disk and returns the fully
* populated [GgufMetadata] tree.
@ -104,10 +124,7 @@ internal class GgufMetadataReaderImpl(
/** Reads the 4byte magic + 4byte version; throws if magic ≠ "GGUF". */
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
val magic = ByteArray(4)
if (input.read(magic) != 4) throw IOException("File too short (no magic)")
if (!magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46))) // "GGUF"
throw IOException("Not a GGUF file (bad magic)")
if (!ensureMagic(input)) throw InvalidFileFormatException()
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
}