data: implement HuggingFace data model, data source with Retrofit API
This commit is contained in:
parent
4b2f769ba8
commit
48fa0b23dc
|
|
@ -0,0 +1,29 @@
|
||||||
|
package com.example.llama.revamp.data.remote
|
||||||
|
|
||||||
|
import okhttp3.ResponseBody
|
||||||
|
import retrofit2.http.GET
|
||||||
|
import retrofit2.http.Path
|
||||||
|
import retrofit2.http.Query
|
||||||
|
import retrofit2.http.Streaming
|
||||||
|
|
||||||
|
interface HuggingFaceApiService {
|
||||||
|
@GET("api/models")
|
||||||
|
suspend fun getModels(
|
||||||
|
@Query("search") search: String? = null,
|
||||||
|
@Query("author") author: String? = null,
|
||||||
|
@Query("filter") filter: String? = null,
|
||||||
|
@Query("sort") sort: String? = null,
|
||||||
|
@Query("direction") direction: String? = null,
|
||||||
|
@Query("limit") limit: Int? = null
|
||||||
|
): List<HuggingFaceModel>
|
||||||
|
|
||||||
|
@GET("api/models/{modelId}")
|
||||||
|
suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModel
|
||||||
|
|
||||||
|
@GET("{modelId}/resolve/main/{filePath}")
|
||||||
|
@Streaming
|
||||||
|
suspend fun downloadModelFile(
|
||||||
|
@Path("modelId") modelId: String,
|
||||||
|
@Path("filePath") filePath: String
|
||||||
|
): ResponseBody
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
package com.example.llama.revamp.data.remote
|
||||||
|
|
||||||
|
data class HuggingFaceModel(
|
||||||
|
val id: String,
|
||||||
|
val modelId: String,
|
||||||
|
val likes: Int?,
|
||||||
|
val trendingScore: Int?,
|
||||||
|
val private: Boolean?,
|
||||||
|
val downloads: Int?,
|
||||||
|
val tags: List<String>?,
|
||||||
|
val pipeline_tag: String?,
|
||||||
|
val library_name: String?,
|
||||||
|
val createdAt: String?
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
package com.example.llama.revamp.data.remote
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
import java.io.File
|
||||||
|
import javax.inject.Inject
|
||||||
|
import javax.inject.Singleton
|
||||||
|
|
||||||
|
interface HuggingFaceRemoteDataSource {
|
||||||
|
suspend fun searchModels(
|
||||||
|
query: String? = "gguf",
|
||||||
|
filter: String? = "text-generation", // Only generative models,
|
||||||
|
sort: String? = "downloads",
|
||||||
|
direction: String? = "-1",
|
||||||
|
limit: Int? = 20
|
||||||
|
): List<HuggingFaceModel>
|
||||||
|
|
||||||
|
suspend fun getModelDetails(modelId: String): HuggingFaceModel
|
||||||
|
|
||||||
|
suspend fun downloadModelFile(modelId: String, filePath: String, outputFile: File): Result<File>
|
||||||
|
}
|
||||||
|
|
||||||
|
@Singleton
|
||||||
|
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||||
|
private val apiService: HuggingFaceApiService
|
||||||
|
) : HuggingFaceRemoteDataSource {
|
||||||
|
|
||||||
|
override suspend fun searchModels(
|
||||||
|
query: String?,
|
||||||
|
filter: String?,
|
||||||
|
sort: String?,
|
||||||
|
direction: String?,
|
||||||
|
limit: Int?
|
||||||
|
) = withContext(Dispatchers.IO) {
|
||||||
|
apiService.getModels(
|
||||||
|
search = query,
|
||||||
|
filter = filter,
|
||||||
|
sort = sort,
|
||||||
|
direction = direction,
|
||||||
|
limit = limit
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override suspend fun getModelDetails(
|
||||||
|
modelId: String
|
||||||
|
) = withContext(Dispatchers.IO) {
|
||||||
|
apiService.getModelDetails(modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
override suspend fun downloadModelFile(
|
||||||
|
modelId: String,
|
||||||
|
filePath: String,
|
||||||
|
outputFile: File
|
||||||
|
): Result<File> = withContext(Dispatchers.IO) {
|
||||||
|
try {
|
||||||
|
val response = apiService.downloadModelFile(modelId, filePath)
|
||||||
|
|
||||||
|
// Create parent directories if needed
|
||||||
|
outputFile.parentFile?.mkdirs()
|
||||||
|
|
||||||
|
// Save the file
|
||||||
|
response.byteStream().use { input ->
|
||||||
|
outputFile.outputStream().use { output ->
|
||||||
|
input.copyTo(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Result.success(outputFile)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Error downloading file $filePath: ${e.message}")
|
||||||
|
Result.failure(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -4,16 +4,19 @@ import android.content.Context
|
||||||
import android.llama.cpp.InferenceEngine
|
import android.llama.cpp.InferenceEngine
|
||||||
import android.llama.cpp.LLamaAndroid
|
import android.llama.cpp.LLamaAndroid
|
||||||
import com.example.llama.revamp.data.local.AppDatabase
|
import com.example.llama.revamp.data.local.AppDatabase
|
||||||
|
import com.example.llama.revamp.data.remote.HuggingFaceApiService
|
||||||
|
import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSource
|
||||||
|
import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSourceImpl
|
||||||
import com.example.llama.revamp.data.repository.ModelRepository
|
import com.example.llama.revamp.data.repository.ModelRepository
|
||||||
import com.example.llama.revamp.data.repository.ModelRepositoryImpl
|
import com.example.llama.revamp.data.repository.ModelRepositoryImpl
|
||||||
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
||||||
import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
|
import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
|
||||||
import com.example.llama.revamp.engine.BenchmarkService
|
import com.example.llama.revamp.engine.BenchmarkService
|
||||||
import com.example.llama.revamp.engine.ConversationService
|
import com.example.llama.revamp.engine.ConversationService
|
||||||
import com.example.llama.revamp.engine.StubInferenceEngine
|
|
||||||
import com.example.llama.revamp.engine.InferenceService
|
import com.example.llama.revamp.engine.InferenceService
|
||||||
import com.example.llama.revamp.engine.InferenceServiceImpl
|
import com.example.llama.revamp.engine.InferenceServiceImpl
|
||||||
import com.example.llama.revamp.engine.ModelLoadingService
|
import com.example.llama.revamp.engine.ModelLoadingService
|
||||||
|
import com.example.llama.revamp.engine.StubInferenceEngine
|
||||||
import com.example.llama.revamp.monitoring.PerformanceMonitor
|
import com.example.llama.revamp.monitoring.PerformanceMonitor
|
||||||
import dagger.Binds
|
import dagger.Binds
|
||||||
import dagger.Module
|
import dagger.Module
|
||||||
|
|
@ -21,6 +24,10 @@ import dagger.Provides
|
||||||
import dagger.hilt.InstallIn
|
import dagger.hilt.InstallIn
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
import dagger.hilt.components.SingletonComponent
|
import dagger.hilt.components.SingletonComponent
|
||||||
|
import okhttp3.OkHttpClient
|
||||||
|
import okhttp3.logging.HttpLoggingInterceptor
|
||||||
|
import retrofit2.Retrofit
|
||||||
|
import retrofit2.converter.gson.GsonConverterFactory
|
||||||
import javax.inject.Singleton
|
import javax.inject.Singleton
|
||||||
|
|
||||||
@Module
|
@Module
|
||||||
|
|
@ -45,9 +52,13 @@ internal abstract class AppModule {
|
||||||
@Binds
|
@Binds
|
||||||
abstract fun bindsSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository
|
abstract fun bindsSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository
|
||||||
|
|
||||||
|
@Binds
|
||||||
|
abstract fun bindHuggingFaceRemoteDataSource(
|
||||||
|
impl: HuggingFaceRemoteDataSourceImpl
|
||||||
|
): HuggingFaceRemoteDataSource
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
@Provides
|
@Provides
|
||||||
@Singleton
|
|
||||||
fun provideInferenceEngine(): InferenceEngine {
|
fun provideInferenceEngine(): InferenceEngine {
|
||||||
val useRealEngine = true
|
val useRealEngine = true
|
||||||
return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine()
|
return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine()
|
||||||
|
|
@ -64,5 +75,22 @@ internal abstract class AppModule {
|
||||||
|
|
||||||
@Provides
|
@Provides
|
||||||
fun providesSystemPromptDao(appDatabase: AppDatabase) = appDatabase.systemPromptDao()
|
fun providesSystemPromptDao(appDatabase: AppDatabase) = appDatabase.systemPromptDao()
|
||||||
|
|
||||||
|
@Provides
|
||||||
|
@Singleton
|
||||||
|
fun provideOkhttpClient() = OkHttpClient.Builder()
|
||||||
|
.addInterceptor(HttpLoggingInterceptor().apply {
|
||||||
|
level = HttpLoggingInterceptor.Level.BODY
|
||||||
|
}).build()
|
||||||
|
|
||||||
|
@Provides
|
||||||
|
@Singleton
|
||||||
|
fun provideHuggingFaceApiService(okHttpClient: OkHttpClient): HuggingFaceApiService =
|
||||||
|
Retrofit.Builder()
|
||||||
|
.baseUrl("https://huggingface.co/")
|
||||||
|
.client(okHttpClient)
|
||||||
|
.addConverterFactory(GsonConverterFactory.create())
|
||||||
|
.build()
|
||||||
|
.create(HuggingFaceApiService::class.java)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue