diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt index 4e8b123824..1c874eb1f9 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt @@ -7,8 +7,11 @@ import android.util.Log import com.example.llama.revamp.data.local.dao.ModelDao import com.example.llama.revamp.data.local.entity.ModelEntity import com.example.llama.revamp.data.model.ModelInfo +import com.example.llama.revamp.data.remote.HuggingFaceModel +import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSource import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker import com.example.llama.revamp.monitoring.StorageMetrics +import com.example.llama.revamp.util.GgufMetadata import com.example.llama.revamp.util.GgufMetadataReader import com.example.llama.revamp.util.copyWithBuffer import com.example.llama.revamp.util.copyWithChannels @@ -77,6 +80,24 @@ interface ModelRepository { */ suspend fun deleteModel(modelId: String) suspend fun deleteModels(modelIds: List) + + /** + * Search models on HuggingFace + */ + suspend fun searchHuggingFaceModels(limit: Int = 20): List + + /** + * Obtain the model details from HuggingFace + */ + suspend fun getHuggingFaceModelDetails(modelId: String): HuggingFaceModel + + /** + * Download and import a HuggingFace model + */ + suspend fun importHuggingFaceModel( + model: HuggingFaceModel, + progressTracker: (Float) -> Unit + ): ModelInfo } class InsufficientStorageException(message: String) : IOException(message) @@ -85,6 +106,7 @@ class InsufficientStorageException(message: String) : IOException(message) class ModelRepositoryImpl @Inject constructor( @ApplicationContext private val context: Context, private val modelDao: ModelDao, + private val huggingFaceRemoteDataSource: HuggingFaceRemoteDataSource, ) : ModelRepository { private val modelsDir = File(context.filesDir, INTERNAL_STORAGE_PATH) @@ -93,6 +115,17 @@ class ModelRepositoryImpl @Inject constructor( if (!modelsDir.exists()) { modelsDir.mkdirs() } } + val modelsSizeBytes: Long + get() = modelsDir.listFiles()?.fold(0L) { totalSize, file -> + totalSize + if (file.isFile) file.length() else 0 + } ?: 0L + + val availableSpaceBytes: Long + get() = StatFs(context.filesDir.path).availableBytes + + val totalSpaceBytes: Long + get() = StatFs(context.filesDir.path).totalBytes + override fun getStorageMetrics(): Flow = flow { while (true) { emit( @@ -306,16 +339,128 @@ class ModelRepositoryImpl @Inject constructor( } } - val modelsSizeBytes: Long - get() = modelsDir.listFiles()?.fold(0L) { totalSize, file -> - totalSize + if (file.isFile) file.length() else 0 - } ?: 0L + override suspend fun searchHuggingFaceModels( + limit: Int + ): List = withContext(Dispatchers.Default) { + huggingFaceRemoteDataSource.searchModels(limit = limit) + } - val availableSpaceBytes: Long - get() = StatFs(context.filesDir.path).availableBytes + override suspend fun getHuggingFaceModelDetails( + modelId: String + ) = withContext(Dispatchers.Default) { + huggingFaceRemoteDataSource.getModelDetails(modelId) + } - val totalSpaceBytes: Long - get() = StatFs(context.filesDir.path).totalBytes + override suspend fun importHuggingFaceModel( + model: HuggingFaceModel, + progressTracker: (Float) -> Unit + ): ModelInfo = withContext(Dispatchers.IO) { + try { + // Find GGUF files in the model repository + val modelFiles = findGgufFiles(model.id) + if (modelFiles.isEmpty()) { + throw IOException("No GGUF files found for model ${model.id}") + } + + // Create directory for the model + val modelDir = File(modelsDir, model.id.replace("/", "_")) + if (!modelDir.exists()) { + modelDir.mkdirs() + } + + var totalDownloaded = 0L + var fileIndex = 0 + val totalFiles = modelFiles.size + + // Download each GGUF file + modelFiles.forEach { filePath -> + fileIndex++ + progressTracker(fileIndex.toFloat() / totalFiles) + + val outputFile = File(modelDir, filePath.substringAfterLast("/")) + huggingFaceRemoteDataSource.downloadModelFile( + model.id, filePath, outputFile + ).onSuccess { file -> + totalDownloaded += file.length() + } + } + + // Create and save the model entity + val modelEntity = ModelEntity( + id = UUID.randomUUID().toString(), // Generate a new ID for local storage + name = model.modelId.split("/").last(), + path = modelDir.absolutePath, + sizeInBytes = totalDownloaded, + metadata = createGgufMetadataFromHfModel(model), + dateAdded = System.currentTimeMillis(), + dateLastUsed = null + ) + + modelDao.insertModel(modelEntity) + return@withContext modelEntity.toModelInfo() + + } catch (e: Exception) { + // Clean up if download fails + File(modelsDir, model.id.replace("/", "_")).let { + if (it.exists()) { + it.deleteRecursively() + } + } + throw e + } + } + + // TODO-han.yin: replace this heuristic approach with API call + private suspend fun findGgufFiles(modelId: String): List = withContext(Dispatchers.IO) { + // In a real implementation, we would query the Hugging Face API or use Git APIs + // to get the file listing. For now, we'll use a simpler approach: + + // 1. Try standard filenames first + val standardNames = listOf( + "model.gguf", + "model-q4_0.gguf", + "model-q4_k_m.gguf", + "model-q5_k_m.gguf", + "model-q8_0.gguf" + ) + + // 2. Check if standard files exist (would require separate API call) + // For now, just return the standard names as a fallback + standardNames + } + + /** + * Convert HuggingFace model metadata to GgufMetadata format + * + * TODO-han.yin: improve this conversion coverage + */ + private fun createGgufMetadataFromHfModel(model: HuggingFaceModel) = + GgufMetadata( + version = GgufMetadata.GgufVersion.VALIDATED_V3, + tensorCount = 0, + kvCount = 0, + basic = GgufMetadata.BasicInfo( + uuid = model.id, + name = model.modelId.split("/").last(), + nameLabel = model.modelId, + sizeLabel = extractModelSize(model.tags) + ), + additional = GgufMetadata.AdditionalInfo( + type = model.pipeline_tag, + description = null, + tags = model.tags, + languages = model.tags?.filter { it.length <= 3 } // language codes + ) + ) + + private fun extractModelSize(tags: List?): String? { + // Try to extract model size from tags + if (tags == null) return null + + // Common model size patterns: 7B, 13B, 70B, etc. + val sizePattern = Regex("\\d+(\\.\\d+)?[Bb]") + return tags.find { it.matches(sizePattern) } + } companion object { private val TAG = ModelRepository::class.java.simpleName @@ -332,9 +477,6 @@ class ModelRepositoryImpl @Inject constructor( private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024 private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024 private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L - - private const val DEFAULT_CONTEXT_SIZE = 8192 - } }