data: implement HuggingFace data model, data source with Retrofit API

This commit is contained in:
Han Yin 2025-04-21 23:07:03 -07:00
parent 4b2f769ba8
commit 48fa0b23dc
4 changed files with 152 additions and 2 deletions

View File

@ -0,0 +1,29 @@
package com.example.llama.revamp.data.remote
import okhttp3.ResponseBody
import retrofit2.http.GET
import retrofit2.http.Path
import retrofit2.http.Query
import retrofit2.http.Streaming
interface HuggingFaceApiService {
@GET("api/models")
suspend fun getModels(
@Query("search") search: String? = null,
@Query("author") author: String? = null,
@Query("filter") filter: String? = null,
@Query("sort") sort: String? = null,
@Query("direction") direction: String? = null,
@Query("limit") limit: Int? = null
): List<HuggingFaceModel>
@GET("api/models/{modelId}")
suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModel
@GET("{modelId}/resolve/main/{filePath}")
@Streaming
suspend fun downloadModelFile(
@Path("modelId") modelId: String,
@Path("filePath") filePath: String
): ResponseBody
}

View File

@ -0,0 +1,14 @@
package com.example.llama.revamp.data.remote
data class HuggingFaceModel(
val id: String,
val modelId: String,
val likes: Int?,
val trendingScore: Int?,
val private: Boolean?,
val downloads: Int?,
val tags: List<String>?,
val pipeline_tag: String?,
val library_name: String?,
val createdAt: String?
)

View File

@ -0,0 +1,79 @@
package com.example.llama.revamp.data.remote
import android.util.Log
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File
import javax.inject.Inject
import javax.inject.Singleton
interface HuggingFaceRemoteDataSource {
suspend fun searchModels(
query: String? = "gguf",
filter: String? = "text-generation", // Only generative models,
sort: String? = "downloads",
direction: String? = "-1",
limit: Int? = 20
): List<HuggingFaceModel>
suspend fun getModelDetails(modelId: String): HuggingFaceModel
suspend fun downloadModelFile(modelId: String, filePath: String, outputFile: File): Result<File>
}
@Singleton
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
private val apiService: HuggingFaceApiService
) : HuggingFaceRemoteDataSource {
override suspend fun searchModels(
query: String?,
filter: String?,
sort: String?,
direction: String?,
limit: Int?
) = withContext(Dispatchers.IO) {
apiService.getModels(
search = query,
filter = filter,
sort = sort,
direction = direction,
limit = limit
)
}
override suspend fun getModelDetails(
modelId: String
) = withContext(Dispatchers.IO) {
apiService.getModelDetails(modelId)
}
override suspend fun downloadModelFile(
modelId: String,
filePath: String,
outputFile: File
): Result<File> = withContext(Dispatchers.IO) {
try {
val response = apiService.downloadModelFile(modelId, filePath)
// Create parent directories if needed
outputFile.parentFile?.mkdirs()
// Save the file
response.byteStream().use { input ->
outputFile.outputStream().use { output ->
input.copyTo(output)
}
}
Result.success(outputFile)
} catch (e: Exception) {
Log.e(TAG, "Error downloading file $filePath: ${e.message}")
Result.failure(e)
}
}
companion object {
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
}
}

View File

@ -4,16 +4,19 @@ import android.content.Context
import android.llama.cpp.InferenceEngine
import android.llama.cpp.LLamaAndroid
import com.example.llama.revamp.data.local.AppDatabase
import com.example.llama.revamp.data.remote.HuggingFaceApiService
import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSource
import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSourceImpl
import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.data.repository.ModelRepositoryImpl
import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
import com.example.llama.revamp.engine.BenchmarkService
import com.example.llama.revamp.engine.ConversationService
import com.example.llama.revamp.engine.StubInferenceEngine
import com.example.llama.revamp.engine.InferenceService
import com.example.llama.revamp.engine.InferenceServiceImpl
import com.example.llama.revamp.engine.ModelLoadingService
import com.example.llama.revamp.engine.StubInferenceEngine
import com.example.llama.revamp.monitoring.PerformanceMonitor
import dagger.Binds
import dagger.Module
@ -21,6 +24,10 @@ import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.android.qualifiers.ApplicationContext
import dagger.hilt.components.SingletonComponent
import okhttp3.OkHttpClient
import okhttp3.logging.HttpLoggingInterceptor
import retrofit2.Retrofit
import retrofit2.converter.gson.GsonConverterFactory
import javax.inject.Singleton
@Module
@ -45,9 +52,13 @@ internal abstract class AppModule {
@Binds
abstract fun bindsSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository
@Binds
abstract fun bindHuggingFaceRemoteDataSource(
impl: HuggingFaceRemoteDataSourceImpl
): HuggingFaceRemoteDataSource
companion object {
@Provides
@Singleton
fun provideInferenceEngine(): InferenceEngine {
val useRealEngine = true
return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine()
@ -64,5 +75,22 @@ internal abstract class AppModule {
@Provides
fun providesSystemPromptDao(appDatabase: AppDatabase) = appDatabase.systemPromptDao()
@Provides
@Singleton
fun provideOkhttpClient() = OkHttpClient.Builder()
.addInterceptor(HttpLoggingInterceptor().apply {
level = HttpLoggingInterceptor.Level.BODY
}).build()
@Provides
@Singleton
fun provideHuggingFaceApiService(okHttpClient: OkHttpClient): HuggingFaceApiService =
Retrofit.Builder()
.baseUrl("https://huggingface.co/")
.client(okHttpClient)
.addConverterFactory(GsonConverterFactory.create())
.build()
.create(HuggingFaceApiService::class.java)
}
}