diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt index 6e6c1a6c3c..b26fb3d95a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt @@ -7,6 +7,8 @@ import android.util.Log import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.withContext +import java.io.FileNotFoundException +import java.io.IOException import javax.inject.Inject import javax.inject.Singleton @@ -26,14 +28,14 @@ interface HuggingFaceRemoteDataSource { direction: String? = "-1", limit: Int? = SEARCH_RESULT_LIMIT, full: Boolean = true, - ): List + ): Result> suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails /** * Obtain selected HuggingFace model's GGUF file size from HTTP header */ - suspend fun getFileSize(modelId: String, filePath: String): Long? + suspend fun getFileSize(modelId: String, filePath: String): Result /** * Download selected HuggingFace model's GGUF file via DownloadManager @@ -57,14 +59,23 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( limit: Int?, full: Boolean, ) = withContext(Dispatchers.IO) { - apiService.getModels( - search = query, - filter = filter, - sort = sort, - direction = direction, - limit = limit, - full = full, - ).filter { it.gated != true && it.private != true && it.getGgufFilename() != null } + try { + apiService.getModels( + search = query, + filter = filter, + sort = sort, + direction = direction, + limit = limit, + full = full, + ).filter { + it.gated != true && it.private != true && it.getGgufFilename() != null + }.let { + if (it.isEmpty()) Result.failure(FileNotFoundException()) else Result.success(it) + } + } catch (e: Exception) { + Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}") + Result.failure(e) + } } override suspend fun getModelDetails( @@ -76,18 +87,26 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( override suspend fun getFileSize( modelId: String, filePath: String - ): Long? = withContext(Dispatchers.IO) { + ): Result = withContext(Dispatchers.IO) { try { - apiService.getModelFileHeader(modelId, filePath).let { - if (it.isSuccessful) { - it.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull() + apiService.getModelFileHeader(modelId, filePath).let { resp -> + if (resp.isSuccessful) { + resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let { + Result.success(it) + } ?: Result.failure(IOException("Content-Length header missing")) } else { - null + Result.failure( + when (resp.code()) { + 401 -> SecurityException("Model requires authentication") + 404 -> FileNotFoundException("Model file not found") + else -> IOException("Failed to get file info: HTTP ${resp.code()}") + } + ) } } } catch (e: Exception) { Log.e(TAG, "Error getting file size for $modelId: ${e.message}") - null + Result.failure(e) } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt index 4a398b332f..3511365c4f 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt @@ -86,7 +86,7 @@ interface ModelRepository { /** * Search models on HuggingFace */ - suspend fun searchHuggingFaceModels(limit: Int = 20): List + suspend fun searchHuggingFaceModels(limit: Int = 20): Result> /** * Obtain the model details from HuggingFace @@ -96,7 +96,7 @@ interface ModelRepository { /** * Obtain the model's size from HTTP response header */ - suspend fun getHuggingFaceModelFileSize(downloadInfo: HuggingFaceDownloadInfo): Long? + suspend fun getHuggingFaceModelFileSize(downloadInfo: HuggingFaceDownloadInfo): Result /** * Download a HuggingFace model via system download manager @@ -351,7 +351,7 @@ class ModelRepositoryImpl @Inject constructor( override suspend fun searchHuggingFaceModels( limit: Int - ): List = withContext(Dispatchers.IO) { + ) = withContext(Dispatchers.IO) { huggingFaceRemoteDataSource.searchModels(limit = limit) } @@ -363,7 +363,7 @@ class ModelRepositoryImpl @Inject constructor( override suspend fun getHuggingFaceModelFileSize( downloadInfo: HuggingFaceDownloadInfo, - ): Long? = withContext(Dispatchers.IO) { + ): Result = withContext(Dispatchers.IO) { huggingFaceRemoteDataSource.getFileSize(downloadInfo.modelId, downloadInfo.filename) } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt index 6223c0a601..dea30c0f49 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt @@ -38,6 +38,9 @@ import kotlinx.coroutines.flow.combine import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import java.io.FileNotFoundException +import java.io.IOException +import java.net.SocketTimeoutException +import java.net.UnknownHostException import javax.inject.Inject @HiltViewModel @@ -246,11 +249,23 @@ class ModelsManagementViewModel @Inject constructor( huggingFaceQueryJob = viewModelScope.launch { _managementState.emit(Download.Querying) try { - val models = modelRepository.searchHuggingFaceModels() - Log.d(TAG, "Fetched ${models.size} models from HuggingFace:") - _managementState.emit(Download.Ready(models)) + modelRepository.searchHuggingFaceModels().fold( + onSuccess = { models -> + Log.d(TAG, "Fetched ${models.size} models from HuggingFace:") + _managementState.emit(Download.Ready(models)) + }, + onFailure = { throw it } + ) } catch (_: CancellationException) { // no-op + } catch (_: UnknownHostException) { + _managementState.value = Download.Error(message = "No internet connection") + } catch (_: SocketTimeoutException) { + _managementState.value = Download.Error(message = "Connection timed out") + } catch (e: IOException) { + _managementState.value = Download.Error(message = "Network error: ${e.message}") + } catch (_: FileNotFoundException) { + _managementState.emit(Download.Error(message = "No eligible models")) } catch (e: Exception) { _managementState.emit(Download.Error(message = e.message ?: "Unknown error")) } @@ -267,17 +282,24 @@ class ModelsManagementViewModel @Inject constructor( val downloadInfo = model.toDownloadInfo() requireNotNull(downloadInfo) { "Download URL is missing!" } - val actualSize = modelRepository.getHuggingFaceModelFileSize(downloadInfo) - requireNotNull(actualSize) { "Unknown model file size!" } - Log.d(TAG, "Model file size: ${formatFileByteSize(actualSize)}") - - modelRepository.downloadHuggingFaceModel(downloadInfo, actualSize) - .onSuccess { downloadId -> - activeDownloads[downloadId] = model - _managementState.value = Download.Dispatched(downloadInfo) - } - .onFailure { throw it } - + modelRepository.getHuggingFaceModelFileSize(downloadInfo).fold( + onSuccess = { actualSize -> + Log.d(TAG, "Model file size: ${formatFileByteSize(actualSize)}") + modelRepository.downloadHuggingFaceModel(downloadInfo, actualSize) + .onSuccess { downloadId -> + activeDownloads[downloadId] = model + _managementState.value = Download.Dispatched(downloadInfo) + } + .onFailure { throw it } + }, + onFailure = { throw it } + ) + } catch (_: UnknownHostException) { + _managementState.value = Download.Error(message = "No internet connection") + } catch (_: SocketTimeoutException) { + _managementState.value = Download.Error(message = "Connection timed out") + } catch (e: IOException) { + _managementState.value = Download.Error(message = "Network error: ${e.message}") } catch (e: InsufficientStorageException) { _managementState.value = Download.Error( message = e.message ?: "Insufficient storage space to download ${model.modelId}", @@ -343,7 +365,6 @@ class ModelsManagementViewModel @Inject constructor( companion object { private val TAG = ModelsManagementViewModel::class.java.simpleName - private const val SUBSCRIPTION_TIMEOUT_MS = 5000L private const val SUCCESS_RESET_TIMEOUT_MS = 1000L } }