data: update Model data repository to support fetching HuggingFace models
This commit is contained in:
parent
48fa0b23dc
commit
cfbd271c84
|
|
@ -7,8 +7,11 @@ import android.util.Log
|
||||||
import com.example.llama.revamp.data.local.dao.ModelDao
|
import com.example.llama.revamp.data.local.dao.ModelDao
|
||||||
import com.example.llama.revamp.data.local.entity.ModelEntity
|
import com.example.llama.revamp.data.local.entity.ModelEntity
|
||||||
import com.example.llama.revamp.data.model.ModelInfo
|
import com.example.llama.revamp.data.model.ModelInfo
|
||||||
|
import com.example.llama.revamp.data.remote.HuggingFaceModel
|
||||||
|
import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSource
|
||||||
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
|
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
|
||||||
import com.example.llama.revamp.monitoring.StorageMetrics
|
import com.example.llama.revamp.monitoring.StorageMetrics
|
||||||
|
import com.example.llama.revamp.util.GgufMetadata
|
||||||
import com.example.llama.revamp.util.GgufMetadataReader
|
import com.example.llama.revamp.util.GgufMetadataReader
|
||||||
import com.example.llama.revamp.util.copyWithBuffer
|
import com.example.llama.revamp.util.copyWithBuffer
|
||||||
import com.example.llama.revamp.util.copyWithChannels
|
import com.example.llama.revamp.util.copyWithChannels
|
||||||
|
|
@ -77,6 +80,24 @@ interface ModelRepository {
|
||||||
*/
|
*/
|
||||||
suspend fun deleteModel(modelId: String)
|
suspend fun deleteModel(modelId: String)
|
||||||
suspend fun deleteModels(modelIds: List<String>)
|
suspend fun deleteModels(modelIds: List<String>)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Search models on HuggingFace
|
||||||
|
*/
|
||||||
|
suspend fun searchHuggingFaceModels(limit: Int = 20): List<HuggingFaceModel>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain the model details from HuggingFace
|
||||||
|
*/
|
||||||
|
suspend fun getHuggingFaceModelDetails(modelId: String): HuggingFaceModel
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Download and import a HuggingFace model
|
||||||
|
*/
|
||||||
|
suspend fun importHuggingFaceModel(
|
||||||
|
model: HuggingFaceModel,
|
||||||
|
progressTracker: (Float) -> Unit
|
||||||
|
): ModelInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
class InsufficientStorageException(message: String) : IOException(message)
|
/**
 * [IOException] subtype that carries only a [message].
 *
 * NOTE(review): presumably thrown when there is not enough free space to import a model —
 * the throw sites are outside this view, confirm before relying on it.
 */
class InsufficientStorageException(message: String) : IOException(message)
|
||||||
|
|
@ -85,6 +106,7 @@ class InsufficientStorageException(message: String) : IOException(message)
|
||||||
class ModelRepositoryImpl @Inject constructor(
|
class ModelRepositoryImpl @Inject constructor(
|
||||||
@ApplicationContext private val context: Context,
|
@ApplicationContext private val context: Context,
|
||||||
private val modelDao: ModelDao,
|
private val modelDao: ModelDao,
|
||||||
|
private val huggingFaceRemoteDataSource: HuggingFaceRemoteDataSource,
|
||||||
) : ModelRepository {
|
) : ModelRepository {
|
||||||
|
|
||||||
private val modelsDir = File(context.filesDir, INTERNAL_STORAGE_PATH)
|
private val modelsDir = File(context.filesDir, INTERNAL_STORAGE_PATH)
|
||||||
|
|
@ -93,6 +115,17 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
if (!modelsDir.exists()) { modelsDir.mkdirs() }
|
if (!modelsDir.exists()) { modelsDir.mkdirs() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Combined size, in bytes, of the regular files located directly inside [modelsDir].
 *
 * Sub-directories are not descended into, and `0` is returned when the directory
 * cannot be listed.
 */
val modelsSizeBytes: Long
    get() {
        val entries = modelsDir.listFiles() ?: return 0L
        return entries.filter { it.isFile }.sumOf { it.length() }
    }
|
||||||
|
|
||||||
|
/** Available bytes reported by [StatFs] for the filesystem that holds the app's `filesDir`. */
val availableSpaceBytes: Long
    get() {
        val internalStorage = StatFs(context.filesDir.path)
        return internalStorage.availableBytes
    }
|
||||||
|
|
||||||
|
/** Total bytes reported by [StatFs] for the filesystem that holds the app's `filesDir`. */
val totalSpaceBytes: Long
    get() {
        val internalStorage = StatFs(context.filesDir.path)
        return internalStorage.totalBytes
    }
|
||||||
|
|
||||||
override fun getStorageMetrics(): Flow<StorageMetrics> = flow {
|
override fun getStorageMetrics(): Flow<StorageMetrics> = flow {
|
||||||
while (true) {
|
while (true) {
|
||||||
emit(
|
emit(
|
||||||
|
|
@ -306,16 +339,128 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val modelsSizeBytes: Long
|
/**
 * Searches HuggingFace for models by delegating to [huggingFaceRemoteDataSource].
 *
 * The call runs on [Dispatchers.IO] — the lookup is presumably network-bound rather than
 * CPU-bound, and this matches the dispatcher used by [importHuggingFaceModel].
 *
 * @param limit forwarded to the remote data source as its `limit` argument.
 * @return the models returned by the remote data source.
 */
override suspend fun searchHuggingFaceModels(
    limit: Int,
): List<HuggingFaceModel> = withContext(Dispatchers.IO) {
    huggingFaceRemoteDataSource.searchModels(limit = limit)
}
|
||||||
|
|
||||||
val availableSpaceBytes: Long
|
/**
 * Fetches the details of a single HuggingFace model by delegating to
 * [huggingFaceRemoteDataSource].
 *
 * The call runs on [Dispatchers.IO] — the lookup is presumably network-bound rather than
 * CPU-bound, and this matches the dispatcher used by [importHuggingFaceModel]. The return
 * type is spelled out so the override's public contract does not depend on inference.
 *
 * @param modelId identifier passed through to the remote data source.
 * @return the model details returned by the remote data source.
 */
override suspend fun getHuggingFaceModelDetails(
    modelId: String,
): HuggingFaceModel = withContext(Dispatchers.IO) {
    huggingFaceRemoteDataSource.getModelDetails(modelId)
}
|
||||||
|
|
||||||
val totalSpaceBytes: Long
|
/**
 * Downloads the GGUF files of a HuggingFace model into app storage and registers the
 * result in the local database.
 *
 * Candidate file names come from [findGgufFiles]; each one is downloaded into a per-model
 * directory under [modelsDir] (the model id with `/` replaced by `_`). A failed download of
 * an individual candidate is tolerated, but the import as a whole fails when not a single
 * candidate could be downloaded, so an empty model is never inserted.
 *
 * On any exception the per-model directory is deleted before the exception is rethrown.
 *
 * @param model the HuggingFace model to import.
 * @param progressTracker invoked after each candidate file has been attempted, with the
 *   fraction of candidates processed so far; the last value is `1.0`.
 * @return the [ModelInfo] converted from the newly inserted [ModelEntity].
 * @throws IOException when there are no candidate files or none could be downloaded.
 */
override suspend fun importHuggingFaceModel(
    model: HuggingFaceModel,
    progressTracker: (Float) -> Unit
): ModelInfo = withContext(Dispatchers.IO) {
    // Resolved once so that the download target and the failure cleanup always agree.
    val modelDir = File(modelsDir, model.id.replace("/", "_"))

    try {
        // Find GGUF files in the model repository
        val modelFiles = findGgufFiles(model.id)
        if (modelFiles.isEmpty()) {
            throw IOException("No GGUF files found for model ${model.id}")
        }

        if (!modelDir.exists()) {
            modelDir.mkdirs()
        }

        var totalDownloaded = 0L
        var downloadedCount = 0

        // Download each candidate; progress is reported once the attempt has finished,
        // so 100% is only announced after the last file has actually been processed.
        modelFiles.forEachIndexed { index, filePath ->
            val outputFile = File(modelDir, filePath.substringAfterLast("/"))
            huggingFaceRemoteDataSource.downloadModelFile(
                model.id, filePath, outputFile
            ).onSuccess { file ->
                totalDownloaded += file.length()
                downloadedCount++
            }
            progressTracker((index + 1).toFloat() / modelFiles.size)
        }

        // Individual failures are best-effort, but importing nothing at all is an error.
        if (downloadedCount == 0) {
            throw IOException(
                "None of the candidate GGUF files could be downloaded for model ${model.id}"
            )
        }

        // Create and save the model entity
        val modelEntity = ModelEntity(
            id = UUID.randomUUID().toString(), // Generate a new ID for local storage
            name = model.modelId.split("/").last(),
            path = modelDir.absolutePath,
            sizeInBytes = totalDownloaded,
            metadata = createGgufMetadataFromHfModel(model),
            dateAdded = System.currentTimeMillis(),
            dateLastUsed = null
        )

        modelDao.insertModel(modelEntity)
        return@withContext modelEntity.toModelInfo()
    } catch (e: Exception) {
        // Clean up if the import fails, then let the caller see the original failure.
        if (modelDir.exists()) {
            modelDir.deleteRecursively()
        }
        throw e
    }
}
|
||||||
|
|
||||||
|
/**
 * Returns the GGUF file names to try for a HuggingFace model.
 *
 * This is a placeholder heuristic: [modelId] is not consulted and no remote listing is
 * performed; a fixed list of standard file names is returned as a fallback.
 *
 * TODO-han.yin: replace this heuristic approach with API call
 */
private suspend fun findGgufFiles(modelId: String): List<String> {
    // Standard file names to attempt; whether they exist remotely is not checked here
    // (that would require a separate API call).
    val candidateNames = listOf(
        "model.gguf",
        "model-q4_0.gguf",
        "model-q4_k_m.gguf",
        "model-q5_k_m.gguf",
        "model-q8_0.gguf"
    )

    return withContext(Dispatchers.IO) { candidateNames }
}
|
||||||
|
|
||||||
|
/**
 * Convert HuggingFace model metadata to GgufMetadata format.
 *
 * No GGUF file is parsed here: the version is fixed to `VALIDATED_V3`, the tensor and
 * key-value counts are `0`, and the remaining fields are derived from the model's ids,
 * tags and pipeline tag.
 *
 * TODO-han.yin: improve the coverage of this conversion
 */
private fun createGgufMetadataFromHfModel(model: HuggingFaceModel) =
    GgufMetadata(
        version = GgufMetadata.GgufVersion.VALIDATED_V3,
        tensorCount = 0,
        kvCount = 0,
        basic = GgufMetadata.BasicInfo(
            uuid = model.id,
            name = model.modelId.substringAfterLast("/"),
            nameLabel = model.modelId,
            sizeLabel = extractModelSize(model.tags)
        ),
        additional = GgufMetadata.AdditionalInfo(
            type = model.pipeline_tag,
            description = null,
            tags = model.tags,
            // Heuristic: treat 2-3 letter lowercase tags as language codes ("en", "zh", "fra").
            // A bare length check would also accept size tags such as "7B", which
            // extractModelSize() already interprets as a model size.
            languages = model.tags?.filter { tag ->
                tag.length in 2..3 && tag.all { it in 'a'..'z' }
            }
        )
    )
|
||||||
|
|
||||||
|
/**
 * Picks the first tag that is, in its entirety, a model-size label such as `7B`, `13B`
 * or `1.5b` (digits, optional decimal part, then `B`/`b`).
 *
 * @param tags tags to inspect; may be `null`.
 * @return the first matching tag, or `null` when [tags] is `null` or nothing matches.
 */
private fun extractModelSize(tags: List<String>?): String? =
    tags?.let { candidates ->
        // Common model size patterns: 7B, 13B, 70B, etc.
        val sizeTag = Regex("\\d+(\\.\\d+)?[Bb]")
        candidates.firstOrNull { sizeTag.matches(it) }
    }
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private val TAG = ModelRepository::class.java.simpleName
|
private val TAG = ModelRepository::class.java.simpleName
|
||||||
|
|
@ -332,9 +477,6 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
||||||
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
||||||
private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L
|
private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L
|
||||||
|
|
||||||
private const val DEFAULT_CONTEXT_SIZE = 8192
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue