remote: refactor HuggingFaceApiService; implement download feature in HuggingFaceRemoteDataSource

This commit is contained in:
Han Yin 2025-07-06 18:30:14 -07:00
parent 5138cb6a85
commit aa22467e01
4 changed files with 192 additions and 150 deletions

View File

@ -1,7 +1,9 @@
package com.example.llama.data.remote package com.example.llama.data.remote
import okhttp3.ResponseBody import okhttp3.ResponseBody
import retrofit2.Response
import retrofit2.http.GET import retrofit2.http.GET
import retrofit2.http.HEAD
import retrofit2.http.Path import retrofit2.http.Path
import retrofit2.http.Query import retrofit2.http.Query
import retrofit2.http.Streaming import retrofit2.http.Streaming
@ -21,10 +23,17 @@ interface HuggingFaceApiService {
@GET("api/models/{modelId}") @GET("api/models/{modelId}")
suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModelDetails suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModelDetails
@GET("{modelId}/resolve/main/{filePath}") @HEAD("{modelId}/resolve/main/{filename}")
suspend fun getModelFileHeader(
@Path("modelId", encoded = true) modelId: String,
@Path("filename", encoded = true) filename: String
): Response<Void>
@Deprecated("Use DownloadManager instead!")
@GET("{modelId}/resolve/main/{filename}")
@Streaming @Streaming
suspend fun downloadModelFile( suspend fun downloadModelFile(
@Path("modelId") modelId: String, @Path("modelId") modelId: String,
@Path("filePath") filePath: String @Path("filename") filename: String
): ResponseBody ): ResponseBody
} }

View File

@ -1,31 +1,50 @@
package com.example.llama.data.remote package com.example.llama.data.remote
import android.net.Uri
import androidx.core.net.toUri
import com.example.llama.di.HUGGINGFACE_HOST
import java.util.Date import java.util.Date
private const val FILE_EXTENSION_GGUF = ".gguf"
data class HuggingFaceModel( data class HuggingFaceModel(
val _id: String, val _id: String,
val id: String, val id: String,
val modelId: String, val modelId: String,
val author: String, val author: String,
val createdAt: Date?, val createdAt: Date,
val lastModified: Date?, val lastModified: Date,
val pipeline_tag: String,
val tags: List<String>,
val private: Boolean,
val gated: Boolean,
val likes: Int,
val downloads: Int,
val sha: String,
val siblings: List<Sibling>,
val library_name: String?, val library_name: String?,
val pipeline_tag: String?,
val tags: List<String>?,
val private: Boolean?,
val gated: Boolean?,
val likes: Int?,
val downloads: Int?,
val sha: String?,
val siblings: List<Sibling>?,
) { ) {
data class Sibling( data class Sibling(
val rfilename: String, val rfilename: String,
) )
/**
 * Returns the repository-relative name of the first sibling file whose name ends with
 * [FILE_EXTENSION_GGUF], or `null` when this model lists no such file.
 */
fun getGgufFilename(): String? =
    // firstOrNull rather than first: `first { }` throws NoSuchElementException when nothing
    // matches, which contradicts the nullable return type that callers null-check.
    siblings.firstOrNull { it.rfilename.endsWith(FILE_EXTENSION_GGUF) }?.rfilename
/**
 * Builds the [HuggingFaceDownloadInfo] describing this model's GGUF file.
 *
 * Evaluates to `null` when [getGgufFilename] returns `null`.
 */
fun toDownloadInfo() =
    getGgufFilename()?.let { ggufName ->
        HuggingFaceDownloadInfo(_id = _id, modelId = modelId, filename = ggufName)
    }
}
/**
 * Identifies a single model file to fetch from HuggingFace.
 *
 * @property _id identifier carried along with the download request
 * @property modelId model identifier used as the first path segment of [uri]
 * @property filename file name used as the last path segment of [uri]
 */
data class HuggingFaceDownloadInfo(
    val _id: String,
    val modelId: String,
    val filename: String,
) {
    /**
     * Download location, composed as `<HUGGINGFACE_HOST><modelId>/resolve/main/<filename>`.
     *
     * NOTE(review): the host constant is concatenated without a separator, so it presumably
     * already ends with `/` — confirm against its definition.
     */
    val uri: Uri
        get() {
            val location = "$HUGGINGFACE_HOST$modelId/resolve/main/$filename"
            return location.toUri()
        }
}

View File

@ -1,25 +1,47 @@
package com.example.llama.data.remote package com.example.llama.data.remote
import android.app.DownloadManager
import android.content.Context
import android.os.Environment
import android.util.Log import android.util.Log
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import java.io.File
import javax.inject.Inject import javax.inject.Inject
import javax.inject.Singleton import javax.inject.Singleton
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
private const val FILTER_TEXT_GENERATION = "text-generation"
private const val SORT_BY_DOWNLOADS = "downloads"
private const val SEARCH_RESULT_LIMIT = 20
interface HuggingFaceRemoteDataSource { interface HuggingFaceRemoteDataSource {
/**
* Query openly available Q4_0 GGUF models on HuggingFace
*/
suspend fun searchModels( suspend fun searchModels(
query: String? = "gguf q4_0", query: String? = QUERY_Q4_0_GGUF,
filter: String? = "text-generation", // Only generative models, filter: String? = FILTER_TEXT_GENERATION,
sort: String? = "downloads", sort: String? = SORT_BY_DOWNLOADS,
direction: String? = "-1", direction: String? = "-1",
limit: Int? = 20, limit: Int? = SEARCH_RESULT_LIMIT,
full: Boolean = true, full: Boolean = true,
): List<HuggingFaceModel> ): List<HuggingFaceModel>
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
suspend fun downloadModelFile(modelId: String, filePath: String, outputFile: File): Result<File> /**
* Obtain selected HuggingFace model's GGUF file size from HTTP header
*/
suspend fun getFileSize(modelId: String, filePath: String): Long?
/**
* Download selected HuggingFace model's GGUF file via DownloadManager
*/
suspend fun downloadModelFile(
context: Context,
downloadInfo: HuggingFaceDownloadInfo,
): Result<Unit>
} }
@Singleton @Singleton
@ -42,7 +64,7 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
direction = direction, direction = direction,
limit = limit, limit = limit,
full = full, full = full,
) ).filter { it.gated != true && it.private != true }
} }
override suspend fun getModelDetails( override suspend fun getModelDetails(
@ -51,32 +73,98 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
apiService.getModelDetails(modelId) apiService.getModelDetails(modelId)
} }
override suspend fun downloadModelFile( override suspend fun getFileSize(
modelId: String, modelId: String,
filePath: String, filePath: String
outputFile: File ): Long? = withContext(Dispatchers.IO) {
): Result<File> = withContext(Dispatchers.IO) {
try { try {
val response = apiService.downloadModelFile(modelId, filePath) apiService.getModelFileHeader(modelId, filePath).let {
if (it.isSuccessful) {
// Create parent directories if needed it.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()
outputFile.parentFile?.mkdirs() } else {
null
// Save the file
response.byteStream().use { input ->
outputFile.outputStream().use { output ->
input.copyTo(output)
} }
} }
Result.success(outputFile)
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error downloading file $filePath: ${e.message}") Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
null
}
}
/**
 * Enqueues a download of [downloadInfo] with the system [DownloadManager] and, after a short
 * delay, checks once whether the request has already failed.
 *
 * @param context used to obtain the [DownloadManager] system service
 * @param downloadInfo the file to download; its `uri` is the request source and its
 *   `filename` is the destination name inside the public Downloads directory
 * @return `Result.success(Unit)` when the request was enqueued and is not reported as failed
 *   (or its status could not be read); `Result.failure` when the early status check reports
 *   `STATUS_FAILED` or when enqueuing throws
 */
override suspend fun downloadModelFile(
    context: Context,
    downloadInfo: HuggingFaceDownloadInfo,
): Result<Unit> = withContext(Dispatchers.IO) {
    try {
        val downloadManager =
            context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager

        val request = DownloadManager.Request(downloadInfo.uri).apply {
            setTitle("HuggingFace model download")
            setDescription("Downloading ${downloadInfo.filename}")
            setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
            setDestinationInExternalPublicDir(
                Environment.DIRECTORY_DOWNLOADS,
                downloadInfo.filename
            )
            setAllowedNetworkTypes(
                DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
            )
            setAllowedOverMetered(true)
            setAllowedOverRoaming(false)
        }

        Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
        val downloadId = downloadManager.enqueue(request)

        // Give the DownloadManager a moment to pick the request up before the one-off check.
        delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)

        // `use` keeps the cursor open while the status AND the failure reason are read and
        // closes it on every path; previously the cursor was closed before the reason lookup.
        // A null/empty cursor or a missing status column is treated as "not failed".
        val query = DownloadManager.Query().setFilterById(downloadId)
        val failureMessage: String? = downloadManager.query(query)?.use { cursor ->
            if (!cursor.moveToFirst()) return@use null

            val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)
            if (statusIndex < 0 || cursor.getInt(statusIndex) != DownloadManager.STATUS_FAILED) {
                // Pending, paused, running, successful, or status unavailable.
                return@use null
            }

            val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON)
            val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1
            when (reason) {
                DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
                DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
                DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
                DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
                DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
                DownloadManager.ERROR_FILE_ERROR -> "File error"
                else -> "Unknown error"
            }
        }

        if (failureMessage != null) {
            Result.failure(Exception(failureMessage))
        } else {
            Result.success(Unit)
        }
    } catch (e: kotlin.coroutines.cancellation.CancellationException) {
        // Cancellation (e.g. during delay) must propagate, not be reported as a failed download.
        throw e
    } catch (e: Exception) {
        Log.e(TAG, "Failed to enqueue download: ${e.message}")
        Result.failure(e)
    }
}
companion object { companion object {
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"
private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L
} }
} }

View File

@ -9,6 +9,7 @@ import com.example.llama.data.local.dao.ModelDao
import com.example.llama.data.local.entity.ModelEntity import com.example.llama.data.local.entity.ModelEntity
import com.example.llama.data.model.GgufMetadata import com.example.llama.data.model.GgufMetadata
import com.example.llama.data.model.ModelInfo import com.example.llama.data.model.ModelInfo
import com.example.llama.data.remote.HuggingFaceDownloadInfo
import com.example.llama.data.remote.HuggingFaceModel import com.example.llama.data.remote.HuggingFaceModel
import com.example.llama.data.remote.HuggingFaceModelDetails import com.example.llama.data.remote.HuggingFaceModelDetails
import com.example.llama.data.remote.HuggingFaceRemoteDataSource import com.example.llama.data.remote.HuggingFaceRemoteDataSource
@ -92,13 +93,18 @@ interface ModelRepository {
*/ */
suspend fun getHuggingFaceModelDetails(modelId: String): HuggingFaceModelDetails suspend fun getHuggingFaceModelDetails(modelId: String): HuggingFaceModelDetails
/**
* Obtain the model's size from HTTP response header
*/
suspend fun getHuggingFaceModelFileSize(downloadInfo: HuggingFaceDownloadInfo): Long?
/** /**
* Download and import a HuggingFace model * Download and import a HuggingFace model
*/ */
suspend fun importHuggingFaceModel( suspend fun importHuggingFaceModel(
model: HuggingFaceModel, downloadInfo: HuggingFaceDownloadInfo,
progressTracker: (Float) -> Unit actualSize: Long,
): ModelInfo ): Result<Unit>
} }
class InsufficientStorageException(message: String) : IOException(message) class InsufficientStorageException(message: String) : IOException(message)
@ -171,7 +177,9 @@ class ModelRepositoryImpl @Inject constructor(
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
if (!hasEnoughSpaceForImport(fileSize)) { if (!hasEnoughSpaceForImport(fileSize)) {
throw InsufficientStorageException( throw InsufficientStorageException(
"Not enough storage space. Required: ${formatFileByteSize(fileSize)}, Available: ${formatFileByteSize(availableSpaceBytes)}" "Not enough storage space! " +
"Required: ${formatFileByteSize(fileSize)}, " +
"Available: ${formatFileByteSize(availableSpaceBytes)}"
) )
} }
@ -265,7 +273,7 @@ class ModelRepositoryImpl @Inject constructor(
// Add this method to ModelRepositoryImpl.kt // Add this method to ModelRepositoryImpl.kt
private fun hasEnoughSpaceForImport(fileSize: Long): Boolean { private fun hasEnoughSpaceForImport(fileSize: Long): Boolean {
val availableSpace = availableSpaceBytes val availableSpace = availableSpaceBytes
val requiredSpace = (fileSize * MODEL_IMPORT_SPACE_BUFFER_SCALE ).toLong() val requiredSpace = (fileSize * MODEL_IMPORT_SPACE_BUFFER_SCALE).toLong()
return availableSpace >= requiredSpace return availableSpace >= requiredSpace
} }
@ -343,125 +351,43 @@ class ModelRepositoryImpl @Inject constructor(
override suspend fun searchHuggingFaceModels( override suspend fun searchHuggingFaceModels(
limit: Int limit: Int
): List<HuggingFaceModel> = withContext(Dispatchers.Default) { ): List<HuggingFaceModel> = withContext(Dispatchers.IO) {
huggingFaceRemoteDataSource.searchModels(limit = limit) huggingFaceRemoteDataSource.searchModels(limit = limit)
} }
override suspend fun getHuggingFaceModelDetails( override suspend fun getHuggingFaceModelDetails(
modelId: String modelId: String
) = withContext(Dispatchers.Default) { ) = withContext(Dispatchers.IO) {
huggingFaceRemoteDataSource.getModelDetails(modelId) huggingFaceRemoteDataSource.getModelDetails(modelId)
} }
override suspend fun getHuggingFaceModelFileSize(
downloadInfo: HuggingFaceDownloadInfo,
): Long? = withContext(Dispatchers.IO) {
huggingFaceRemoteDataSource.getFileSize(downloadInfo.modelId, downloadInfo.filename)
}
override suspend fun importHuggingFaceModel( override suspend fun importHuggingFaceModel(
model: HuggingFaceModel, downloadInfo: HuggingFaceDownloadInfo,
progressTracker: (Float) -> Unit actualSize: Long,
): ModelInfo = withContext(Dispatchers.IO) { ): Result<Unit> = withContext(Dispatchers.IO) {
try { if (!hasEnoughSpaceForImport(actualSize)) {
// Find GGUF files in the model repository throw InsufficientStorageException(
val modelFiles = findGgufFiles(model.id) "Not enough storage space! " +
if (modelFiles.isEmpty()) { "Estimated required: ${formatFileByteSize(actualSize)}, " +
throw IOException("No GGUF files found for model ${model.id}") "Available: ${formatFileByteSize(availableSpaceBytes)}"
}
// Create directory for the model
val modelDir = File(modelsDir, model.id.replace("/", "_"))
if (!modelDir.exists()) {
modelDir.mkdirs()
}
var totalDownloaded = 0L
var fileIndex = 0
val totalFiles = modelFiles.size
// Download each GGUF file
modelFiles.forEach { filePath ->
fileIndex++
progressTracker(fileIndex.toFloat() / totalFiles)
val outputFile = File(modelDir, filePath.substringAfterLast("/"))
huggingFaceRemoteDataSource.downloadModelFile(
model.id, filePath, outputFile
).onSuccess { file ->
totalDownloaded += file.length()
}
}
// Create and save the model entity
val modelEntity = ModelEntity(
id = UUID.randomUUID().toString(), // Generate a new ID for local storage
name = model.modelId.split("/").last(),
path = modelDir.absolutePath,
sizeInBytes = totalDownloaded,
metadata = createGgufMetadataFromHfModel(model),
dateAdded = System.currentTimeMillis(),
dateLastUsed = null
) )
modelDao.insertModel(modelEntity)
return@withContext modelEntity.toModelInfo()
} catch (e: Exception) {
// Clean up if download fails
File(modelsDir, model.id.replace("/", "_")).let {
if (it.exists()) {
it.deleteRecursively()
}
}
throw e
} }
}
// TODO-han.yin: replace this heuristic approach with API call try {
private suspend fun findGgufFiles(modelId: String): List<String> = withContext(Dispatchers.IO) { huggingFaceRemoteDataSource.downloadModelFile(
// In a real implementation, we would query the Hugging Face API or use Git APIs context = context,
// to get the file listing. For now, we'll use a simpler approach: downloadInfo = downloadInfo,
// 1. Try standard filenames first
val standardNames = listOf(
"model.gguf",
"model-q4_0.gguf",
"model-q4_k_m.gguf",
"model-q5_k_m.gguf",
"model-q8_0.gguf"
)
// 2. Check if standard files exist (would require separate API call)
// For now, just return the standard names as a fallback
standardNames
}
/**
* Convert HuggingFace model metadata to GgufMetadata format
*
* TODO-han.yin: improve this conversion coverage
*/
private fun createGgufMetadataFromHfModel(model: HuggingFaceModel) =
GgufMetadata(
version = GgufMetadata.GgufVersion.VALIDATED_V3,
tensorCount = 0,
kvCount = 0,
basic = GgufMetadata.BasicInfo(
uuid = model.id,
name = model.modelId.split("/").last(),
nameLabel = model.modelId,
sizeLabel = extractModelSize(model.tags)
),
additional = GgufMetadata.AdditionalInfo(
type = model.pipeline_tag,
description = null,
tags = model.tags,
languages = model.tags?.filter { it.length <= 3 } // language codes
) )
) } catch (e: Exception) {
Log.e(TAG, "Import failed: ${e.message}")
private fun extractModelSize(tags: List<String>?): String? { Result.failure(e)
// Try to extract model size from tags }
if (tags == null) return null
// Common model size patterns: 7B, 13B, 70B, etc.
val sizePattern = Regex("\\d+(\\.\\d+)?[Bb]")
return tags.find { it.matches(sizePattern) }
} }
companion object { companion object {