diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceApiService.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceApiService.kt index a4c296f1e9..11398e41b6 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceApiService.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceApiService.kt @@ -1,7 +1,9 @@ package com.example.llama.data.remote import okhttp3.ResponseBody +import retrofit2.Response import retrofit2.http.GET +import retrofit2.http.HEAD import retrofit2.http.Path import retrofit2.http.Query import retrofit2.http.Streaming @@ -21,10 +23,17 @@ interface HuggingFaceApiService { @GET("api/models/{modelId}") suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModelDetails - @GET("{modelId}/resolve/main/{filePath}") + @HEAD("{modelId}/resolve/main/{filename}") + suspend fun getModelFileHeader( + @Path("modelId", encoded = true) modelId: String, + @Path("filename", encoded = true) filename: String + ): Response + + @Deprecated("Use DownloadManager instead!") + @GET("{modelId}/resolve/main/{filename}") @Streaming suspend fun downloadModelFile( @Path("modelId") modelId: String, - @Path("filePath") filePath: String + @Path("filename") filename: String ): ResponseBody } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceModel.kt index f56a6c88b3..78ee31bc1d 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceModel.kt @@ -1,31 +1,50 @@ package com.example.llama.data.remote +import android.net.Uri +import androidx.core.net.toUri +import com.example.llama.di.HUGGINGFACE_HOST import java.util.Date +private const val 
FILE_EXTENSION_GGUF = ".gguf" + data class HuggingFaceModel( val _id: String, val id: String, val modelId: String, val author: String, - val createdAt: Date?, - val lastModified: Date?, + val createdAt: Date, + val lastModified: Date, + + val pipeline_tag: String, + val tags: List, + + val private: Boolean, + val gated: Boolean, + + val likes: Int, + val downloads: Int, + + val sha: String, + val siblings: List, val library_name: String?, - val pipeline_tag: String?, - val tags: List?, - - val private: Boolean?, - val gated: Boolean?, - - val likes: Int?, - val downloads: Int?, - - val sha: String?, - - val siblings: List?, ) { data class Sibling( val rfilename: String, ) + + fun getGgufFilename(): String? = + siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) } /* null when no sibling is a GGUF file, instead of throwing */ + + fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) } +} + +data class HuggingFaceDownloadInfo( + val _id: String, + val modelId: String, + val filename: String, +) { + val uri: Uri + get() = "$HUGGINGFACE_HOST${modelId}/resolve/main/$filename".toUri() } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt index 7f77069dec..8f3434b74a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt @@ -1,25 +1,47 @@ package com.example.llama.data.remote +import android.app.DownloadManager +import android.content.Context +import android.os.Environment import android.util.Log import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay import kotlinx.coroutines.withContext -import java.io.File import javax.inject.Inject import javax.inject.Singleton +private const val QUERY_Q4_0_GGUF = "gguf q4_0" +private const val 
FILTER_TEXT_GENERATION = "text-generation" +private const val SORT_BY_DOWNLOADS = "downloads" +private const val SEARCH_RESULT_LIMIT = 20 + interface HuggingFaceRemoteDataSource { + /** + * Query openly available Q4_0 GGUF models on HuggingFace + */ suspend fun searchModels( - query: String? = "gguf q4_0", - filter: String? = "text-generation", // Only generative models, - sort: String? = "downloads", + query: String? = QUERY_Q4_0_GGUF, + filter: String? = FILTER_TEXT_GENERATION, + sort: String? = SORT_BY_DOWNLOADS, direction: String? = "-1", - limit: Int? = 20, + limit: Int? = SEARCH_RESULT_LIMIT, full: Boolean = true, ): List suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails - suspend fun downloadModelFile(modelId: String, filePath: String, outputFile: File): Result + /** + * Obtain selected HuggingFace model's GGUF file size from HTTP header + */ + suspend fun getFileSize(modelId: String, filePath: String): Long? + + /** + * Download selected HuggingFace model's GGUF file via DownloadManager + */ + suspend fun downloadModelFile( + context: Context, + downloadInfo: HuggingFaceDownloadInfo, + ): Result } @Singleton @@ -42,7 +64,7 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( direction = direction, limit = limit, full = full, - ) + ).filter { it.gated != true && it.private != true } } override suspend fun getModelDetails( @@ -51,32 +73,98 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( apiService.getModelDetails(modelId) } - override suspend fun downloadModelFile( + override suspend fun getFileSize( modelId: String, - filePath: String, - outputFile: File - ): Result = withContext(Dispatchers.IO) { + filePath: String + ): Long? 
= withContext(Dispatchers.IO) { try { - val response = apiService.downloadModelFile(modelId, filePath) - - // Create parent directories if needed - outputFile.parentFile?.mkdirs() - - // Save the file - response.byteStream().use { input -> - outputFile.outputStream().use { output -> - input.copyTo(output) + apiService.getModelFileHeader(modelId, filePath).let { + if (it.isSuccessful) { + it.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull() + } else { + null } } - - Result.success(outputFile) } catch (e: Exception) { - Log.e(TAG, "Error downloading file $filePath: ${e.message}") + Log.e(TAG, "Error getting file size for $modelId: ${e.message}") + null + } + } + + override suspend fun downloadModelFile( + context: Context, + downloadInfo: HuggingFaceDownloadInfo, + ): Result = withContext(Dispatchers.IO) { + try { + val downloadManager = + context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager + val request = DownloadManager.Request(downloadInfo.uri).apply { + setTitle("HuggingFace model download") + setDescription("Downloading ${downloadInfo.filename}") + setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED) + setDestinationInExternalPublicDir( + Environment.DIRECTORY_DOWNLOADS, + downloadInfo.filename + ) + setAllowedNetworkTypes( + DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE + ) + setAllowedOverMetered(true) + setAllowedOverRoaming(false) + } + Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}") + val downloadId = downloadManager.enqueue(request) + + delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY) + + val cursor = downloadManager.query(DownloadManager.Query().setFilterById(downloadId)) + if (cursor != null && cursor.moveToFirst()) { + val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS) + if (statusIndex >= 0) { + val status = cursor.getInt(statusIndex)
+ // Read the failure reason before closing: a closed cursor can no longer be queried + val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON) + val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1 + cursor.close() + + when (status) { + DownloadManager.STATUS_FAILED -> { + val errorMessage = when (reason) { + DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error" + DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage" + DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects" + DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code" + DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download" + DownloadManager.ERROR_FILE_ERROR -> "File error" + else -> "Unknown error" + } + Result.failure(Exception(errorMessage)) + } + else -> { + // Download is pending, paused, or running + Result.success(Unit) + } + } + } else { + // Assume success if we can't check status + cursor.close() + Result.success(Unit) + } + } else { + // Assume success if cursor is empty + cursor?.close() + Result.success(Unit) + } + } catch (e: Exception) { + Log.e(TAG, "Failed to enqueue download: ${e.message}") Result.failure(e) } } companion object { private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName + + private const val HTTP_HEADER_CONTENT_LENGTH = "content-length" + private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt index 64ec4d6d06..a0efa04aea 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt @@ -9,6 +9,7 @@ import com.example.llama.data.local.dao.ModelDao import com.example.llama.data.local.entity.ModelEntity import com.example.llama.data.model.GgufMetadata import com.example.llama.data.model.ModelInfo +import com.example.llama.data.remote.HuggingFaceDownloadInfo import 
com.example.llama.data.remote.HuggingFaceModel import com.example.llama.data.remote.HuggingFaceModelDetails import com.example.llama.data.remote.HuggingFaceRemoteDataSource @@ -92,13 +93,18 @@ interface ModelRepository { */ suspend fun getHuggingFaceModelDetails(modelId: String): HuggingFaceModelDetails + /** + * Obtain the model's size from HTTP response header + */ + suspend fun getHuggingFaceModelFileSize(downloadInfo: HuggingFaceDownloadInfo): Long? + /** * Download and import a HuggingFace model */ suspend fun importHuggingFaceModel( - model: HuggingFaceModel, - progressTracker: (Float) -> Unit - ): ModelInfo + downloadInfo: HuggingFaceDownloadInfo, + actualSize: Long, + ): Result } class InsufficientStorageException(message: String) : IOException(message) @@ -171,7 +177,9 @@ class ModelRepositoryImpl @Inject constructor( val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") if (!hasEnoughSpaceForImport(fileSize)) { throw InsufficientStorageException( - "Not enough storage space. Required: ${formatFileByteSize(fileSize)}, Available: ${formatFileByteSize(availableSpaceBytes)}" + "Not enough storage space! 
" + + "Required: ${formatFileByteSize(fileSize)}, " + + "Available: ${formatFileByteSize(availableSpaceBytes)}" ) } @@ -265,7 +273,7 @@ class ModelRepositoryImpl @Inject constructor( // Add this method to ModelRepositoryImpl.kt private fun hasEnoughSpaceForImport(fileSize: Long): Boolean { val availableSpace = availableSpaceBytes - val requiredSpace = (fileSize * MODEL_IMPORT_SPACE_BUFFER_SCALE ).toLong() + val requiredSpace = (fileSize * MODEL_IMPORT_SPACE_BUFFER_SCALE).toLong() return availableSpace >= requiredSpace } @@ -343,125 +351,43 @@ class ModelRepositoryImpl @Inject constructor( override suspend fun searchHuggingFaceModels( limit: Int - ): List = withContext(Dispatchers.Default) { + ): List = withContext(Dispatchers.IO) { huggingFaceRemoteDataSource.searchModels(limit = limit) } override suspend fun getHuggingFaceModelDetails( modelId: String - ) = withContext(Dispatchers.Default) { + ) = withContext(Dispatchers.IO) { huggingFaceRemoteDataSource.getModelDetails(modelId) } + override suspend fun getHuggingFaceModelFileSize( + downloadInfo: HuggingFaceDownloadInfo, + ): Long? 
= withContext(Dispatchers.IO) { + huggingFaceRemoteDataSource.getFileSize(downloadInfo.modelId, downloadInfo.filename) + } + override suspend fun importHuggingFaceModel( - model: HuggingFaceModel, - progressTracker: (Float) -> Unit - ): ModelInfo = withContext(Dispatchers.IO) { - try { - // Find GGUF files in the model repository - val modelFiles = findGgufFiles(model.id) - if (modelFiles.isEmpty()) { - throw IOException("No GGUF files found for model ${model.id}") - } - - // Create directory for the model - val modelDir = File(modelsDir, model.id.replace("/", "_")) - if (!modelDir.exists()) { - modelDir.mkdirs() - } - - var totalDownloaded = 0L - var fileIndex = 0 - val totalFiles = modelFiles.size - - // Download each GGUF file - modelFiles.forEach { filePath -> - fileIndex++ - progressTracker(fileIndex.toFloat() / totalFiles) - - val outputFile = File(modelDir, filePath.substringAfterLast("/")) - huggingFaceRemoteDataSource.downloadModelFile( - model.id, filePath, outputFile - ).onSuccess { file -> - totalDownloaded += file.length() - } - } - - // Create and save the model entity - val modelEntity = ModelEntity( - id = UUID.randomUUID().toString(), // Generate a new ID for local storage - name = model.modelId.split("/").last(), - path = modelDir.absolutePath, - sizeInBytes = totalDownloaded, - metadata = createGgufMetadataFromHfModel(model), - dateAdded = System.currentTimeMillis(), - dateLastUsed = null + downloadInfo: HuggingFaceDownloadInfo, + actualSize: Long, + ): Result = withContext(Dispatchers.IO) { + if (!hasEnoughSpaceForImport(actualSize)) { + throw InsufficientStorageException( + "Not enough storage space! 
" + + "Estimated required: ${formatFileByteSize(actualSize)}, " + + "Available: ${formatFileByteSize(availableSpaceBytes)}" ) - - modelDao.insertModel(modelEntity) - return@withContext modelEntity.toModelInfo() - - } catch (e: Exception) { - // Clean up if download fails - File(modelsDir, model.id.replace("/", "_")).let { - if (it.exists()) { - it.deleteRecursively() - } - } - throw e } - } - // TODO-han.yin: replace this heuristic approach with API call - private suspend fun findGgufFiles(modelId: String): List = withContext(Dispatchers.IO) { - // In a real implementation, we would query the Hugging Face API or use Git APIs - // to get the file listing. For now, we'll use a simpler approach: - - // 1. Try standard filenames first - val standardNames = listOf( - "model.gguf", - "model-q4_0.gguf", - "model-q4_k_m.gguf", - "model-q5_k_m.gguf", - "model-q8_0.gguf" - ) - - // 2. Check if standard files exist (would require separate API call) - // For now, just return the standard names as a fallback - standardNames - } - - /** - * Convert HuggingFace model metadata to GgufMetadata format - * - * TODO-han.yin: improve this conversion coverage - */ - private fun createGgufMetadataFromHfModel(model: HuggingFaceModel) = - GgufMetadata( - version = GgufMetadata.GgufVersion.VALIDATED_V3, - tensorCount = 0, - kvCount = 0, - basic = GgufMetadata.BasicInfo( - uuid = model.id, - name = model.modelId.split("/").last(), - nameLabel = model.modelId, - sizeLabel = extractModelSize(model.tags) - ), - additional = GgufMetadata.AdditionalInfo( - type = model.pipeline_tag, - description = null, - tags = model.tags, - languages = model.tags?.filter { it.length <= 3 } // language codes + try { + huggingFaceRemoteDataSource.downloadModelFile( + context = context, + downloadInfo = downloadInfo, ) - ) - - private fun extractModelSize(tags: List?): String? { - // Try to extract model size from tags - if (tags == null) return null - - // Common model size patterns: 7B, 13B, 70B, etc. 
- val sizePattern = Regex("\\d+(\\.\\d+)?[Bb]") - return tags.find { it.matches(sizePattern) } + } catch (e: Exception) { + Log.e(TAG, "Import failed: ${e.message}") + Result.failure(e) + } } companion object {