data: implement HuggingFace data model, data source with Retrofit API
This commit is contained in:
parent
4b2f769ba8
commit
48fa0b23dc
|
|
@ -0,0 +1,29 @@
|
|||
package com.example.llama.revamp.data.remote
|
||||
|
||||
import okhttp3.ResponseBody
|
||||
import retrofit2.http.GET
|
||||
import retrofit2.http.Path
|
||||
import retrofit2.http.Query
|
||||
import retrofit2.http.Streaming
|
||||
|
||||
interface HuggingFaceApiService {
|
||||
@GET("api/models")
|
||||
suspend fun getModels(
|
||||
@Query("search") search: String? = null,
|
||||
@Query("author") author: String? = null,
|
||||
@Query("filter") filter: String? = null,
|
||||
@Query("sort") sort: String? = null,
|
||||
@Query("direction") direction: String? = null,
|
||||
@Query("limit") limit: Int? = null
|
||||
): List<HuggingFaceModel>
|
||||
|
||||
@GET("api/models/{modelId}")
|
||||
suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModel
|
||||
|
||||
@GET("{modelId}/resolve/main/{filePath}")
|
||||
@Streaming
|
||||
suspend fun downloadModelFile(
|
||||
@Path("modelId") modelId: String,
|
||||
@Path("filePath") filePath: String
|
||||
): ResponseBody
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
package com.example.llama.revamp.data.remote
|
||||
|
||||
data class HuggingFaceModel(
|
||||
val id: String,
|
||||
val modelId: String,
|
||||
val likes: Int?,
|
||||
val trendingScore: Int?,
|
||||
val private: Boolean?,
|
||||
val downloads: Int?,
|
||||
val tags: List<String>?,
|
||||
val pipeline_tag: String?,
|
||||
val library_name: String?,
|
||||
val createdAt: String?
|
||||
)
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
package com.example.llama.revamp.data.remote
|
||||
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
||||
interface HuggingFaceRemoteDataSource {
|
||||
suspend fun searchModels(
|
||||
query: String? = "gguf",
|
||||
filter: String? = "text-generation", // Only generative models,
|
||||
sort: String? = "downloads",
|
||||
direction: String? = "-1",
|
||||
limit: Int? = 20
|
||||
): List<HuggingFaceModel>
|
||||
|
||||
suspend fun getModelDetails(modelId: String): HuggingFaceModel
|
||||
|
||||
suspend fun downloadModelFile(modelId: String, filePath: String, outputFile: File): Result<File>
|
||||
}
|
||||
|
||||
@Singleton
|
||||
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||
private val apiService: HuggingFaceApiService
|
||||
) : HuggingFaceRemoteDataSource {
|
||||
|
||||
override suspend fun searchModels(
|
||||
query: String?,
|
||||
filter: String?,
|
||||
sort: String?,
|
||||
direction: String?,
|
||||
limit: Int?
|
||||
) = withContext(Dispatchers.IO) {
|
||||
apiService.getModels(
|
||||
search = query,
|
||||
filter = filter,
|
||||
sort = sort,
|
||||
direction = direction,
|
||||
limit = limit
|
||||
)
|
||||
}
|
||||
|
||||
override suspend fun getModelDetails(
|
||||
modelId: String
|
||||
) = withContext(Dispatchers.IO) {
|
||||
apiService.getModelDetails(modelId)
|
||||
}
|
||||
|
||||
override suspend fun downloadModelFile(
|
||||
modelId: String,
|
||||
filePath: String,
|
||||
outputFile: File
|
||||
): Result<File> = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
val response = apiService.downloadModelFile(modelId, filePath)
|
||||
|
||||
// Create parent directories if needed
|
||||
outputFile.parentFile?.mkdirs()
|
||||
|
||||
// Save the file
|
||||
response.byteStream().use { input ->
|
||||
outputFile.outputStream().use { output ->
|
||||
input.copyTo(output)
|
||||
}
|
||||
}
|
||||
|
||||
Result.success(outputFile)
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error downloading file $filePath: ${e.message}")
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
|
||||
}
|
||||
}
|
||||
|
|
@ -4,16 +4,19 @@ import android.content.Context
|
|||
import android.llama.cpp.InferenceEngine
|
||||
import android.llama.cpp.LLamaAndroid
|
||||
import com.example.llama.revamp.data.local.AppDatabase
|
||||
import com.example.llama.revamp.data.remote.HuggingFaceApiService
|
||||
import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSource
|
||||
import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSourceImpl
|
||||
import com.example.llama.revamp.data.repository.ModelRepository
|
||||
import com.example.llama.revamp.data.repository.ModelRepositoryImpl
|
||||
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
||||
import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
|
||||
import com.example.llama.revamp.engine.BenchmarkService
|
||||
import com.example.llama.revamp.engine.ConversationService
|
||||
import com.example.llama.revamp.engine.StubInferenceEngine
|
||||
import com.example.llama.revamp.engine.InferenceService
|
||||
import com.example.llama.revamp.engine.InferenceServiceImpl
|
||||
import com.example.llama.revamp.engine.ModelLoadingService
|
||||
import com.example.llama.revamp.engine.StubInferenceEngine
|
||||
import com.example.llama.revamp.monitoring.PerformanceMonitor
|
||||
import dagger.Binds
|
||||
import dagger.Module
|
||||
|
|
@ -21,6 +24,10 @@ import dagger.Provides
|
|||
import dagger.hilt.InstallIn
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import dagger.hilt.components.SingletonComponent
|
||||
import okhttp3.OkHttpClient
|
||||
import okhttp3.logging.HttpLoggingInterceptor
|
||||
import retrofit2.Retrofit
|
||||
import retrofit2.converter.gson.GsonConverterFactory
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Module
|
||||
|
|
@ -45,9 +52,13 @@ internal abstract class AppModule {
|
|||
@Binds
|
||||
abstract fun bindsSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository
|
||||
|
||||
@Binds
|
||||
abstract fun bindHuggingFaceRemoteDataSource(
|
||||
impl: HuggingFaceRemoteDataSourceImpl
|
||||
): HuggingFaceRemoteDataSource
|
||||
|
||||
companion object {
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideInferenceEngine(): InferenceEngine {
|
||||
val useRealEngine = true
|
||||
return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine()
|
||||
|
|
@ -64,5 +75,22 @@ internal abstract class AppModule {
|
|||
|
||||
@Provides
|
||||
fun providesSystemPromptDao(appDatabase: AppDatabase) = appDatabase.systemPromptDao()
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideOkhttpClient() = OkHttpClient.Builder()
|
||||
.addInterceptor(HttpLoggingInterceptor().apply {
|
||||
level = HttpLoggingInterceptor.Level.BODY
|
||||
}).build()
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideHuggingFaceApiService(okHttpClient: OkHttpClient): HuggingFaceApiService =
|
||||
Retrofit.Builder()
|
||||
.baseUrl("https://huggingface.co/")
|
||||
.client(okHttpClient)
|
||||
.addConverterFactory(GsonConverterFactory.create())
|
||||
.build()
|
||||
.create(HuggingFaceApiService::class.java)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue