remote: refactor HuggingFaceApiService; implement download feature in HuggingFaceRemoteDataSource
This commit is contained in:
parent
5138cb6a85
commit
aa22467e01
|
|
@ -1,7 +1,9 @@
|
||||||
package com.example.llama.data.remote
|
package com.example.llama.data.remote
|
||||||
|
|
||||||
import okhttp3.ResponseBody
|
import okhttp3.ResponseBody
|
||||||
|
import retrofit2.Response
|
||||||
import retrofit2.http.GET
|
import retrofit2.http.GET
|
||||||
|
import retrofit2.http.HEAD
|
||||||
import retrofit2.http.Path
|
import retrofit2.http.Path
|
||||||
import retrofit2.http.Query
|
import retrofit2.http.Query
|
||||||
import retrofit2.http.Streaming
|
import retrofit2.http.Streaming
|
||||||
|
|
@ -21,10 +23,17 @@ interface HuggingFaceApiService {
|
||||||
@GET("api/models/{modelId}")
|
@GET("api/models/{modelId}")
|
||||||
suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModelDetails
|
suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModelDetails
|
||||||
|
|
||||||
@GET("{modelId}/resolve/main/{filePath}")
|
@HEAD("{modelId}/resolve/main/{filename}")
|
||||||
|
suspend fun getModelFileHeader(
|
||||||
|
@Path("modelId", encoded = true) modelId: String,
|
||||||
|
@Path("filename", encoded = true) filename: String
|
||||||
|
): Response<Void>
|
||||||
|
|
||||||
|
@Deprecated("Use DownloadManager instead!")
|
||||||
|
@GET("{modelId}/resolve/main/{filename}")
|
||||||
@Streaming
|
@Streaming
|
||||||
suspend fun downloadModelFile(
|
suspend fun downloadModelFile(
|
||||||
@Path("modelId") modelId: String,
|
@Path("modelId") modelId: String,
|
||||||
@Path("filePath") filePath: String
|
@Path("filename") filename: String
|
||||||
): ResponseBody
|
): ResponseBody
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,50 @@
|
||||||
package com.example.llama.data.remote
|
package com.example.llama.data.remote
|
||||||
|
|
||||||
|
import android.net.Uri
|
||||||
|
import androidx.core.net.toUri
|
||||||
|
import com.example.llama.di.HUGGINGFACE_HOST
|
||||||
import java.util.Date
|
import java.util.Date
|
||||||
|
|
||||||
|
private const val FILE_EXTENSION_GGUF = ".gguf"
|
||||||
|
|
||||||
data class HuggingFaceModel(
|
data class HuggingFaceModel(
|
||||||
val _id: String,
|
val _id: String,
|
||||||
val id: String,
|
val id: String,
|
||||||
val modelId: String,
|
val modelId: String,
|
||||||
|
|
||||||
val author: String,
|
val author: String,
|
||||||
val createdAt: Date?,
|
val createdAt: Date,
|
||||||
val lastModified: Date?,
|
val lastModified: Date,
|
||||||
|
|
||||||
|
val pipeline_tag: String,
|
||||||
|
val tags: List<String>,
|
||||||
|
|
||||||
|
val private: Boolean,
|
||||||
|
val gated: Boolean,
|
||||||
|
|
||||||
|
val likes: Int,
|
||||||
|
val downloads: Int,
|
||||||
|
|
||||||
|
val sha: String,
|
||||||
|
val siblings: List<Sibling>,
|
||||||
|
|
||||||
val library_name: String?,
|
val library_name: String?,
|
||||||
val pipeline_tag: String?,
|
|
||||||
val tags: List<String>?,
|
|
||||||
|
|
||||||
val private: Boolean?,
|
|
||||||
val gated: Boolean?,
|
|
||||||
|
|
||||||
val likes: Int?,
|
|
||||||
val downloads: Int?,
|
|
||||||
|
|
||||||
val sha: String?,
|
|
||||||
|
|
||||||
val siblings: List<Sibling>?,
|
|
||||||
) {
|
) {
|
||||||
data class Sibling(
|
data class Sibling(
|
||||||
val rfilename: String,
|
val rfilename: String,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/**
 * Name of the first entry in [siblings] whose `rfilename` ends with [FILE_EXTENSION_GGUF].
 *
 * @return the matching file name, or `null` when no sibling has that suffix
 */
fun getGgufFilename(): String? =
    // firstOrNull, not first: `first` throws NoSuchElementException when nothing matches,
    // which contradicts the nullable return type callers rely on.
    siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) }
|
||||||
|
|
||||||
|
/**
 * Builds the download descriptor for this model's GGUF file.
 *
 * @return a [HuggingFaceDownloadInfo] carrying [_id], [modelId] and the file name
 *   reported by [getGgufFilename], or `null` when [getGgufFilename] returns `null`
 */
fun toDownloadInfo(): HuggingFaceDownloadInfo? {
    val ggufFilename = getGgufFilename() ?: return null
    return HuggingFaceDownloadInfo(_id, modelId, ggufFilename)
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Describes one file of a HuggingFace model repository that is to be downloaded.
 *
 * @property _id identifier string; `HuggingFaceModel.toDownloadInfo` passes the model's `_id` here
 * @property modelId repository identifier, used as the leading path of [uri]
 * @property filename name of the file inside the repository
 */
data class HuggingFaceDownloadInfo(
    val _id: String,
    val modelId: String,
    val filename: String,
) {
    /**
     * Download address: [HUGGINGFACE_HOST] followed by `<modelId>/resolve/main/<filename>`.
     * Computed on every access rather than stored.
     */
    val uri: Uri
        get() {
            val address = "$HUGGINGFACE_HOST${modelId}/resolve/main/$filename"
            return address.toUri()
        }
}
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,47 @@
|
||||||
package com.example.llama.data.remote
|
package com.example.llama.data.remote
|
||||||
|
|
||||||
|
import android.app.DownloadManager
|
||||||
|
import android.content.Context
|
||||||
|
import android.os.Environment
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
import java.io.File
|
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
import javax.inject.Singleton
|
import javax.inject.Singleton
|
||||||
|
|
||||||
|
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
|
||||||
|
private const val FILTER_TEXT_GENERATION = "text-generation"
|
||||||
|
private const val SORT_BY_DOWNLOADS = "downloads"
|
||||||
|
private const val SEARCH_RESULT_LIMIT = 20
|
||||||
|
|
||||||
interface HuggingFaceRemoteDataSource {
|
interface HuggingFaceRemoteDataSource {
|
||||||
|
/**
|
||||||
|
* Query openly available Q4_0 GGUF models on HuggingFace
|
||||||
|
*/
|
||||||
suspend fun searchModels(
|
suspend fun searchModels(
|
||||||
query: String? = "gguf q4_0",
|
query: String? = QUERY_Q4_0_GGUF,
|
||||||
filter: String? = "text-generation", // Only generative models,
|
filter: String? = FILTER_TEXT_GENERATION,
|
||||||
sort: String? = "downloads",
|
sort: String? = SORT_BY_DOWNLOADS,
|
||||||
direction: String? = "-1",
|
direction: String? = "-1",
|
||||||
limit: Int? = 20,
|
limit: Int? = SEARCH_RESULT_LIMIT,
|
||||||
full: Boolean = true,
|
full: Boolean = true,
|
||||||
): List<HuggingFaceModel>
|
): List<HuggingFaceModel>
|
||||||
|
|
||||||
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
|
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
|
||||||
|
|
||||||
suspend fun downloadModelFile(modelId: String, filePath: String, outputFile: File): Result<File>
|
/**
|
||||||
|
* Obtain selected HuggingFace model's GGUF file size from HTTP header
|
||||||
|
*/
|
||||||
|
suspend fun getFileSize(modelId: String, filePath: String): Long?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Download selected HuggingFace model's GGUF file via DownloadManager
|
||||||
|
*/
|
||||||
|
suspend fun downloadModelFile(
|
||||||
|
context: Context,
|
||||||
|
downloadInfo: HuggingFaceDownloadInfo,
|
||||||
|
): Result<Unit>
|
||||||
}
|
}
|
||||||
|
|
||||||
@Singleton
|
@Singleton
|
||||||
|
|
@ -42,7 +64,7 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||||
direction = direction,
|
direction = direction,
|
||||||
limit = limit,
|
limit = limit,
|
||||||
full = full,
|
full = full,
|
||||||
)
|
).filter { it.gated != true && it.private != true }
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun getModelDetails(
|
override suspend fun getModelDetails(
|
||||||
|
|
@ -51,32 +73,98 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||||
apiService.getModelDetails(modelId)
|
apiService.getModelDetails(modelId)
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun downloadModelFile(
|
override suspend fun getFileSize(
|
||||||
modelId: String,
|
modelId: String,
|
||||||
filePath: String,
|
filePath: String
|
||||||
outputFile: File
|
): Long? = withContext(Dispatchers.IO) {
|
||||||
): Result<File> = withContext(Dispatchers.IO) {
|
|
||||||
try {
|
try {
|
||||||
val response = apiService.downloadModelFile(modelId, filePath)
|
apiService.getModelFileHeader(modelId, filePath).let {
|
||||||
|
if (it.isSuccessful) {
|
||||||
// Create parent directories if needed
|
it.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()
|
||||||
outputFile.parentFile?.mkdirs()
|
} else {
|
||||||
|
null
|
||||||
// Save the file
|
|
||||||
response.byteStream().use { input ->
|
|
||||||
outputFile.outputStream().use { output ->
|
|
||||||
input.copyTo(output)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Result.success(outputFile)
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Log.e(TAG, "Error downloading file $filePath: ${e.message}")
|
Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Hands the download of [downloadInfo]'s file to the system [DownloadManager], then
 * re-queries the request once after [DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY] milliseconds
 * to detect a download that failed right away.
 *
 * Success only means the request was enqueued and was not reported as failed at the
 * time of that single check; the download itself may still be pending, paused or running.
 *
 * @param context used to obtain the [DownloadManager] system service
 * @param downloadInfo provides the source URI and the destination file name
 * @return failure if enqueueing or querying throws, or if the request reports
 *   [DownloadManager.STATUS_FAILED]; success otherwise, including when the status
 *   cannot be read
 */
override suspend fun downloadModelFile(
    context: Context,
    downloadInfo: HuggingFaceDownloadInfo,
): Result<Unit> = withContext(Dispatchers.IO) {
    try {
        val downloadManager =
            context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
        val request = DownloadManager.Request(downloadInfo.uri).apply {
            setTitle("HuggingFace model download")
            setDescription("Downloading ${downloadInfo.filename}")
            setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
            setDestinationInExternalPublicDir(
                Environment.DIRECTORY_DOWNLOADS,
                downloadInfo.filename
            )
            setAllowedNetworkTypes(
                DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
            )
            setAllowedOverMetered(true)
            setAllowedOverRoaming(false)
        }
        Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
        val downloadId = downloadManager.enqueue(request)

        delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)

        // Assume success if the query yields no cursor at all
        val cursor = downloadManager.query(DownloadManager.Query().setFilterById(downloadId))
            ?: return@withContext Result.success(Unit)

        // `use` keeps the cursor open while the failure reason is read and closes it on
        // every path; closing it before reading COLUMN_REASON made that read throw.
        cursor.use { row ->
            // Assume success if the row or its status column is unavailable
            val statusIndex =
                if (row.moveToFirst()) row.getColumnIndex(DownloadManager.COLUMN_STATUS) else -1
            if (statusIndex < 0 || row.getInt(statusIndex) != DownloadManager.STATUS_FAILED) {
                // Status unreadable, or anything other than failed (e.g. pending, paused, running)
                Result.success(Unit)
            } else {
                val reasonIndex = row.getColumnIndex(DownloadManager.COLUMN_REASON)
                val reason = if (reasonIndex >= 0) row.getInt(reasonIndex) else -1
                Result.failure<Unit>(Exception(describeDownloadFailure(reason)))
            }
        }
    } catch (e: kotlin.coroutines.cancellation.CancellationException) {
        // Cancellation (e.g. during delay) must propagate, not be reported as a download failure
        throw e
    } catch (e: Exception) {
        Log.e(TAG, "Failed to enqueue download: ${e.message}")
        Result.failure(e)
    }
}

/** Maps a [DownloadManager] `COLUMN_REASON` error code to a short human-readable message. */
private fun describeDownloadFailure(reason: Int): String = when (reason) {
    DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
    DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
    DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
    DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
    DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
    DownloadManager.ERROR_FILE_ERROR -> "File error"
    else -> "Unknown error"
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
|
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
|
||||||
|
|
||||||
|
private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"
|
||||||
|
private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import com.example.llama.data.local.dao.ModelDao
|
||||||
import com.example.llama.data.local.entity.ModelEntity
|
import com.example.llama.data.local.entity.ModelEntity
|
||||||
import com.example.llama.data.model.GgufMetadata
|
import com.example.llama.data.model.GgufMetadata
|
||||||
import com.example.llama.data.model.ModelInfo
|
import com.example.llama.data.model.ModelInfo
|
||||||
|
import com.example.llama.data.remote.HuggingFaceDownloadInfo
|
||||||
import com.example.llama.data.remote.HuggingFaceModel
|
import com.example.llama.data.remote.HuggingFaceModel
|
||||||
import com.example.llama.data.remote.HuggingFaceModelDetails
|
import com.example.llama.data.remote.HuggingFaceModelDetails
|
||||||
import com.example.llama.data.remote.HuggingFaceRemoteDataSource
|
import com.example.llama.data.remote.HuggingFaceRemoteDataSource
|
||||||
|
|
@ -92,13 +93,18 @@ interface ModelRepository {
|
||||||
*/
|
*/
|
||||||
suspend fun getHuggingFaceModelDetails(modelId: String): HuggingFaceModelDetails
|
suspend fun getHuggingFaceModelDetails(modelId: String): HuggingFaceModelDetails
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain the model's size from HTTP response header
|
||||||
|
*/
|
||||||
|
suspend fun getHuggingFaceModelFileSize(downloadInfo: HuggingFaceDownloadInfo): Long?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Download and import a HuggingFace model
|
* Download and import a HuggingFace model
|
||||||
*/
|
*/
|
||||||
suspend fun importHuggingFaceModel(
|
suspend fun importHuggingFaceModel(
|
||||||
model: HuggingFaceModel,
|
downloadInfo: HuggingFaceDownloadInfo,
|
||||||
progressTracker: (Float) -> Unit
|
actualSize: Long,
|
||||||
): ModelInfo
|
): Result<Unit>
|
||||||
}
|
}
|
||||||
|
|
||||||
class InsufficientStorageException(message: String) : IOException(message)
|
class InsufficientStorageException(message: String) : IOException(message)
|
||||||
|
|
@ -171,7 +177,9 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
||||||
if (!hasEnoughSpaceForImport(fileSize)) {
|
if (!hasEnoughSpaceForImport(fileSize)) {
|
||||||
throw InsufficientStorageException(
|
throw InsufficientStorageException(
|
||||||
"Not enough storage space. Required: ${formatFileByteSize(fileSize)}, Available: ${formatFileByteSize(availableSpaceBytes)}"
|
"Not enough storage space! " +
|
||||||
|
"Required: ${formatFileByteSize(fileSize)}, " +
|
||||||
|
"Available: ${formatFileByteSize(availableSpaceBytes)}"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -265,7 +273,7 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
// Add this method to ModelRepositoryImpl.kt
|
// Add this method to ModelRepositoryImpl.kt
|
||||||
private fun hasEnoughSpaceForImport(fileSize: Long): Boolean {
|
private fun hasEnoughSpaceForImport(fileSize: Long): Boolean {
|
||||||
val availableSpace = availableSpaceBytes
|
val availableSpace = availableSpaceBytes
|
||||||
val requiredSpace = (fileSize * MODEL_IMPORT_SPACE_BUFFER_SCALE ).toLong()
|
val requiredSpace = (fileSize * MODEL_IMPORT_SPACE_BUFFER_SCALE).toLong()
|
||||||
return availableSpace >= requiredSpace
|
return availableSpace >= requiredSpace
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -343,126 +351,44 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
|
|
||||||
override suspend fun searchHuggingFaceModels(
|
override suspend fun searchHuggingFaceModels(
|
||||||
limit: Int
|
limit: Int
|
||||||
): List<HuggingFaceModel> = withContext(Dispatchers.Default) {
|
): List<HuggingFaceModel> = withContext(Dispatchers.IO) {
|
||||||
huggingFaceRemoteDataSource.searchModels(limit = limit)
|
huggingFaceRemoteDataSource.searchModels(limit = limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun getHuggingFaceModelDetails(
|
override suspend fun getHuggingFaceModelDetails(
|
||||||
modelId: String
|
modelId: String
|
||||||
) = withContext(Dispatchers.Default) {
|
) = withContext(Dispatchers.IO) {
|
||||||
huggingFaceRemoteDataSource.getModelDetails(modelId)
|
huggingFaceRemoteDataSource.getModelDetails(modelId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override suspend fun getHuggingFaceModelFileSize(
|
||||||
|
downloadInfo: HuggingFaceDownloadInfo,
|
||||||
|
): Long? = withContext(Dispatchers.IO) {
|
||||||
|
huggingFaceRemoteDataSource.getFileSize(downloadInfo.modelId, downloadInfo.filename)
|
||||||
|
}
|
||||||
|
|
||||||
override suspend fun importHuggingFaceModel(
|
override suspend fun importHuggingFaceModel(
|
||||||
model: HuggingFaceModel,
|
downloadInfo: HuggingFaceDownloadInfo,
|
||||||
progressTracker: (Float) -> Unit
|
actualSize: Long,
|
||||||
): ModelInfo = withContext(Dispatchers.IO) {
|
): Result<Unit> = withContext(Dispatchers.IO) {
|
||||||
|
if (!hasEnoughSpaceForImport(actualSize)) {
|
||||||
|
throw InsufficientStorageException(
|
||||||
|
"Not enough storage space! " +
|
||||||
|
"Estimated required: ${formatFileByteSize(actualSize)}, " +
|
||||||
|
"Available: ${formatFileByteSize(availableSpaceBytes)}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Find GGUF files in the model repository
|
|
||||||
val modelFiles = findGgufFiles(model.id)
|
|
||||||
if (modelFiles.isEmpty()) {
|
|
||||||
throw IOException("No GGUF files found for model ${model.id}")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create directory for the model
|
|
||||||
val modelDir = File(modelsDir, model.id.replace("/", "_"))
|
|
||||||
if (!modelDir.exists()) {
|
|
||||||
modelDir.mkdirs()
|
|
||||||
}
|
|
||||||
|
|
||||||
var totalDownloaded = 0L
|
|
||||||
var fileIndex = 0
|
|
||||||
val totalFiles = modelFiles.size
|
|
||||||
|
|
||||||
// Download each GGUF file
|
|
||||||
modelFiles.forEach { filePath ->
|
|
||||||
fileIndex++
|
|
||||||
progressTracker(fileIndex.toFloat() / totalFiles)
|
|
||||||
|
|
||||||
val outputFile = File(modelDir, filePath.substringAfterLast("/"))
|
|
||||||
huggingFaceRemoteDataSource.downloadModelFile(
|
huggingFaceRemoteDataSource.downloadModelFile(
|
||||||
model.id, filePath, outputFile
|
context = context,
|
||||||
).onSuccess { file ->
|
downloadInfo = downloadInfo,
|
||||||
totalDownloaded += file.length()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create and save the model entity
|
|
||||||
val modelEntity = ModelEntity(
|
|
||||||
id = UUID.randomUUID().toString(), // Generate a new ID for local storage
|
|
||||||
name = model.modelId.split("/").last(),
|
|
||||||
path = modelDir.absolutePath,
|
|
||||||
sizeInBytes = totalDownloaded,
|
|
||||||
metadata = createGgufMetadataFromHfModel(model),
|
|
||||||
dateAdded = System.currentTimeMillis(),
|
|
||||||
dateLastUsed = null
|
|
||||||
)
|
)
|
||||||
|
|
||||||
modelDao.insertModel(modelEntity)
|
|
||||||
return@withContext modelEntity.toModelInfo()
|
|
||||||
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
// Clean up if download fails
|
Log.e(TAG, "Import failed: ${e.message}")
|
||||||
File(modelsDir, model.id.replace("/", "_")).let {
|
Result.failure(e)
|
||||||
if (it.exists()) {
|
|
||||||
it.deleteRecursively()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
throw e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO-han.yin: replace this heuristic approach with API call
|
|
||||||
private suspend fun findGgufFiles(modelId: String): List<String> = withContext(Dispatchers.IO) {
|
|
||||||
// In a real implementation, we would query the Hugging Face API or use Git APIs
|
|
||||||
// to get the file listing. For now, we'll use a simpler approach:
|
|
||||||
|
|
||||||
// 1. Try standard filenames first
|
|
||||||
val standardNames = listOf(
|
|
||||||
"model.gguf",
|
|
||||||
"model-q4_0.gguf",
|
|
||||||
"model-q4_k_m.gguf",
|
|
||||||
"model-q5_k_m.gguf",
|
|
||||||
"model-q8_0.gguf"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 2. Check if standard files exist (would require separate API call)
|
|
||||||
// For now, just return the standard names as a fallback
|
|
||||||
standardNames
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert HuggingFace model metadata to GgufMetadata format
|
|
||||||
*
|
|
||||||
* TODO-han.yin: improve this conversion coverage
|
|
||||||
*/
|
|
||||||
private fun createGgufMetadataFromHfModel(model: HuggingFaceModel) =
|
|
||||||
GgufMetadata(
|
|
||||||
version = GgufMetadata.GgufVersion.VALIDATED_V3,
|
|
||||||
tensorCount = 0,
|
|
||||||
kvCount = 0,
|
|
||||||
basic = GgufMetadata.BasicInfo(
|
|
||||||
uuid = model.id,
|
|
||||||
name = model.modelId.split("/").last(),
|
|
||||||
nameLabel = model.modelId,
|
|
||||||
sizeLabel = extractModelSize(model.tags)
|
|
||||||
),
|
|
||||||
additional = GgufMetadata.AdditionalInfo(
|
|
||||||
type = model.pipeline_tag,
|
|
||||||
description = null,
|
|
||||||
tags = model.tags,
|
|
||||||
languages = model.tags?.filter { it.length <= 3 } // language codes
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
private fun extractModelSize(tags: List<String>?): String? {
|
|
||||||
// Try to extract model size from tags
|
|
||||||
if (tags == null) return null
|
|
||||||
|
|
||||||
// Common model size patterns: 7B, 13B, 70B, etc.
|
|
||||||
val sizePattern = Regex("\\d+(\\.\\d+)?[Bb]")
|
|
||||||
return tags.find { it.matches(sizePattern) }
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private val TAG = ModelRepository::class.java.simpleName
|
private val TAG = ModelRepository::class.java.simpleName
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue