lib: read & validate the magic number from the picked source file before executing the import
This commit is contained in:
parent
7968216235
commit
381994234c
|
|
@ -2,6 +2,7 @@ package com.example.llama.data.repo
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.llama.cpp.gguf.GgufMetadataReader
|
import android.llama.cpp.gguf.GgufMetadataReader
|
||||||
|
import android.llama.cpp.gguf.InvalidFileFormatException
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import android.os.StatFs
|
import android.os.StatFs
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
|
|
@ -171,9 +172,15 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
throw IllegalStateException("Another import is already in progress!")
|
throw IllegalStateException("Another import is already in progress!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check file info
|
||||||
val fileInfo = localFileDataSource.getFileInfo(uri)
|
val fileInfo = localFileDataSource.getFileInfo(uri)
|
||||||
val fileSize = size ?: fileInfo?.size ?: throw FileNotFoundException("File size N/A")
|
val fileSize = size ?: fileInfo?.size ?: throw FileNotFoundException("File size N/A")
|
||||||
val fileName = name ?: fileInfo?.name ?: throw FileNotFoundException("File name N/A")
|
val fileName = name ?: fileInfo?.name ?: throw FileNotFoundException("File name N/A")
|
||||||
|
if (!ggufMetadataReader.ensureSourceFileFormat(context, uri)) {
|
||||||
|
throw InvalidFileFormatException()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for enough storage
|
||||||
if (!hasEnoughSpaceForImport(fileSize)) {
|
if (!hasEnoughSpaceForImport(fileSize)) {
|
||||||
throw InsufficientStorageException(
|
throw InsufficientStorageException(
|
||||||
"Not enough storage space! " +
|
"Not enough storage space! " +
|
||||||
|
|
@ -182,7 +189,6 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
val modelFile = File(modelsDir, fileName)
|
val modelFile = File(modelsDir, fileName)
|
||||||
|
|
||||||
importJob = coroutineContext[Job]
|
importJob = coroutineContext[Job]
|
||||||
currentModelFile = modelFile
|
currentModelFile = modelFile
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import android.content.Context
|
||||||
import android.content.Context.RECEIVER_EXPORTED
|
import android.content.Context.RECEIVER_EXPORTED
|
||||||
import android.content.Intent
|
import android.content.Intent
|
||||||
import android.content.IntentFilter
|
import android.content.IntentFilter
|
||||||
|
import android.llama.cpp.gguf.InvalidFileFormatException
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
|
|
@ -208,13 +209,18 @@ class ModelsManagementViewModel @Inject constructor(
|
||||||
_managementState.value = Importation.Importing(progress, fileName, fileSize)
|
_managementState.value = Importation.Importing(progress, fileName, fileSize)
|
||||||
}
|
}
|
||||||
_managementState.value = Importation.Success(model)
|
_managementState.value = Importation.Success(model)
|
||||||
|
} catch (_: InvalidFileFormatException) {
|
||||||
|
_managementState.value = Importation.Error(
|
||||||
|
message = "Not a valid GGUF model!",
|
||||||
|
)
|
||||||
} catch (e: InsufficientStorageException) {
|
} catch (e: InsufficientStorageException) {
|
||||||
_managementState.value = Importation.Error(
|
_managementState.value = Importation.Error(
|
||||||
message = e.message ?: "Insufficient storage space to import $uri",
|
message = e.message ?: "Insufficient storage space to import $fileName",
|
||||||
)
|
)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Unknown exception importing $fileName", e)
|
||||||
_managementState.value = Importation.Error(
|
_managementState.value = Importation.Error(
|
||||||
message = e.message ?: "Unknown error importing $uri",
|
message = e.message ?: "Unknown error importing $fileName",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -376,6 +382,7 @@ sealed class ModelManagementState {
|
||||||
data class Confirming(val uri: Uri, val fileName: String, val fileSize: Long) : Importation()
|
data class Confirming(val uri: Uri, val fileName: String, val fileSize: Long) : Importation()
|
||||||
data class Importing(val progress: Float = 0f, val fileName: String, val fileSize: Long, val isCancelling: Boolean = false) : Importation()
|
data class Importing(val progress: Float = 0f, val fileName: String, val fileSize: Long, val isCancelling: Boolean = false) : Importation()
|
||||||
data class Success(val model: ModelInfo) : Importation()
|
data class Success(val model: ModelInfo) : Importation()
|
||||||
|
// TODO-han.yin: Add an optional explanation URL for more info!
|
||||||
data class Error(val message: String) : Importation()
|
data class Error(val message: String) : Importation()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
package android.llama.cpp.gguf
|
package android.llama.cpp.gguf
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl
|
import android.llama.cpp.internal.gguf.GgufMetadataReaderImpl
|
||||||
|
import android.net.Uri
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -8,13 +10,23 @@ import java.io.IOException
|
||||||
* Use `GgufMetadataReader.create()` to get an instance.
|
* Use `GgufMetadataReader.create()` to get an instance.
|
||||||
*/
|
*/
|
||||||
interface GgufMetadataReader {
|
interface GgufMetadataReader {
|
||||||
|
/**
|
||||||
|
* Reads the magic number from the specified file path.
|
||||||
|
*
|
||||||
|
* @param context Context for obtaining ContentResolver
|
||||||
|
* @param uri Uri to the GGUF file provided by ContentProvider
|
||||||
|
* @return true if file is valid GGUF, otherwise false
|
||||||
|
* @throws InvalidFileFormatException if file format is invalid
|
||||||
|
*/
|
||||||
|
suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reads and parses GGUF metadata from the specified file path.
|
* Reads and parses GGUF metadata from the specified file path.
|
||||||
*
|
*
|
||||||
* @param path The absolute path to the GGUF file
|
* @param path The absolute path to the GGUF file
|
||||||
* @return Structured metadata extracted from the file
|
* @return Structured metadata extracted from the file
|
||||||
* @throws IOException if file cannot be read
|
* @throws IOException if file is damaged or cannot be read
|
||||||
* @throws IllegalArgumentException if file format is invalid
|
* @throws InvalidFileFormatException if file format is invalid
|
||||||
*/
|
*/
|
||||||
suspend fun readStructuredMetadata(path: String): GgufMetadata
|
suspend fun readStructuredMetadata(path: String): GgufMetadata
|
||||||
|
|
||||||
|
|
@ -50,3 +62,5 @@ interface GgufMetadataReader {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class InvalidFileFormatException : IOException()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
package android.llama.cpp.internal.gguf
|
package android.llama.cpp.internal.gguf
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
import android.llama.cpp.gguf.GgufMetadata
|
import android.llama.cpp.gguf.GgufMetadata
|
||||||
import android.llama.cpp.gguf.GgufMetadataReader
|
import android.llama.cpp.gguf.GgufMetadataReader
|
||||||
|
import android.llama.cpp.gguf.InvalidFileFormatException
|
||||||
|
import android.net.Uri
|
||||||
import java.io.File
|
import java.io.File
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
|
|
@ -25,7 +28,7 @@ internal class GgufMetadataReaderImpl(
|
||||||
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
|
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
|
||||||
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
|
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
|
||||||
companion object {
|
companion object {
|
||||||
private val codeMap = values().associateBy(MetadataType::code)
|
private val codeMap = entries.associateBy(MetadataType::code)
|
||||||
fun fromCode(code: Int): MetadataType = codeMap[code]
|
fun fromCode(code: Int): MetadataType = codeMap[code]
|
||||||
?: throw IOException("Unknown metadata value type code: $code")
|
?: throw IOException("Unknown metadata value type code: $code")
|
||||||
}
|
}
|
||||||
|
|
@ -65,6 +68,23 @@ internal class GgufMetadataReaderImpl(
|
||||||
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
|
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reads the magic number from the specified file path.
|
||||||
|
*
|
||||||
|
* @param context Context for obtaining ContentResolver
|
||||||
|
* @param uri Uri to the GGUF file provided by ContentProvider
|
||||||
|
* @return true if file is valid GGUF, otherwise false
|
||||||
|
*/
|
||||||
|
override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
|
||||||
|
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
|
||||||
|
|
||||||
|
/** Reads the 4‑byte magic; throws if magic ≠ "GGUF". */
|
||||||
|
private fun ensureMagic(input: InputStream): Boolean {
|
||||||
|
val magic = ByteArray(4)
|
||||||
|
if (input.read(magic) != 4) throw IOException("Not a valid file!")
|
||||||
|
return magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* High‑level entry point: parses a `.gguf` file on disk and returns the fully
|
* High‑level entry point: parses a `.gguf` file on disk and returns the fully
|
||||||
* populated [GgufMetadata] tree.
|
* populated [GgufMetadata] tree.
|
||||||
|
|
@ -104,10 +124,7 @@ internal class GgufMetadataReaderImpl(
|
||||||
|
|
||||||
/** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
|
/** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
|
||||||
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
|
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
|
||||||
val magic = ByteArray(4)
|
if (!ensureMagic(input)) throw InvalidFileFormatException()
|
||||||
if (input.read(magic) != 4) throw IOException("File too short (no magic)")
|
|
||||||
if (!magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46))) // "GGUF"
|
|
||||||
throw IOException("Not a GGUF file (bad magic)")
|
|
||||||
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
|
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue