remote: refine HuggingFaceModel data class
This commit is contained in:
parent
3370bd409c
commit
fe9baace7f
|
|
@ -14,7 +14,8 @@ interface HuggingFaceApiService {
|
||||||
@Query("filter") filter: String? = null,
|
@Query("filter") filter: String? = null,
|
||||||
@Query("sort") sort: String? = null,
|
@Query("sort") sort: String? = null,
|
||||||
@Query("direction") direction: String? = null,
|
@Query("direction") direction: String? = null,
|
||||||
@Query("limit") limit: Int? = null
|
@Query("limit") limit: Int? = null,
|
||||||
|
@Query("full") full: Boolean? = null,
|
||||||
): List<HuggingFaceModel>
|
): List<HuggingFaceModel>
|
||||||
|
|
||||||
@GET("api/models/{modelId}")
|
@GET("api/models/{modelId}")
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,32 @@
|
||||||
package com.example.llama.data.remote
|
package com.example.llama.data.remote
|
||||||
|
|
||||||
|
import java.util.Date
|
||||||
|
|
||||||
data class HuggingFaceModel(
|
data class HuggingFaceModel(
|
||||||
|
val _id: String,
|
||||||
val id: String,
|
val id: String,
|
||||||
val modelId: String,
|
val modelId: String,
|
||||||
|
|
||||||
|
val author: String,
|
||||||
|
val createdAt: Date?,
|
||||||
|
val lastModified: Date?,
|
||||||
|
|
||||||
|
val library_name: String?,
|
||||||
|
val pipeline_tag: String?,
|
||||||
|
val tags: List<String>?,
|
||||||
|
|
||||||
|
val private: Boolean?,
|
||||||
|
val gated: Boolean?,
|
||||||
|
|
||||||
val likes: Int?,
|
val likes: Int?,
|
||||||
val trendingScore: Int?,
|
val trendingScore: Int?,
|
||||||
val private: Boolean?,
|
|
||||||
val downloads: Int?,
|
val downloads: Int?,
|
||||||
val tags: List<String>?,
|
|
||||||
val pipeline_tag: String?,
|
val sha: String?,
|
||||||
val library_name: String?,
|
|
||||||
val createdAt: String?
|
val siblings: List<Sibling>?,
|
||||||
|
) {
|
||||||
|
data class Sibling(
|
||||||
|
val rfilename: String,
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,12 @@ import javax.inject.Singleton
|
||||||
|
|
||||||
interface HuggingFaceRemoteDataSource {
|
interface HuggingFaceRemoteDataSource {
|
||||||
suspend fun searchModels(
|
suspend fun searchModels(
|
||||||
query: String? = "gguf",
|
query: String? = "gguf q4_0",
|
||||||
filter: String? = "text-generation", // Only generative models,
|
filter: String? = "text-generation", // Only generative models,
|
||||||
sort: String? = "downloads",
|
sort: String? = "downloads",
|
||||||
direction: String? = "-1",
|
direction: String? = "-1",
|
||||||
limit: Int? = 20
|
limit: Int? = 20,
|
||||||
|
full: Boolean = true,
|
||||||
): List<HuggingFaceModel>
|
): List<HuggingFaceModel>
|
||||||
|
|
||||||
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
|
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
|
||||||
|
|
@ -31,14 +32,16 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||||
filter: String?,
|
filter: String?,
|
||||||
sort: String?,
|
sort: String?,
|
||||||
direction: String?,
|
direction: String?,
|
||||||
limit: Int?
|
limit: Int?,
|
||||||
|
full: Boolean,
|
||||||
) = withContext(Dispatchers.IO) {
|
) = withContext(Dispatchers.IO) {
|
||||||
apiService.getModels(
|
apiService.getModels(
|
||||||
search = query,
|
search = query,
|
||||||
filter = filter,
|
filter = filter,
|
||||||
sort = sort,
|
sort = sort,
|
||||||
direction = direction,
|
direction = direction,
|
||||||
limit = limit
|
limit = limit,
|
||||||
|
full = full,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import com.example.llama.engine.ModelLoadingService
|
||||||
import com.example.llama.engine.StubInferenceEngine
|
import com.example.llama.engine.StubInferenceEngine
|
||||||
import com.example.llama.engine.StubTierDetection
|
import com.example.llama.engine.StubTierDetection
|
||||||
import com.example.llama.monitoring.PerformanceMonitor
|
import com.example.llama.monitoring.PerformanceMonitor
|
||||||
|
import com.google.gson.GsonBuilder
|
||||||
import dagger.Binds
|
import dagger.Binds
|
||||||
import dagger.Module
|
import dagger.Module
|
||||||
import dagger.Provides
|
import dagger.Provides
|
||||||
|
|
@ -33,6 +34,8 @@ import retrofit2.Retrofit
|
||||||
import retrofit2.converter.gson.GsonConverterFactory
|
import retrofit2.converter.gson.GsonConverterFactory
|
||||||
import javax.inject.Singleton
|
import javax.inject.Singleton
|
||||||
|
|
||||||
|
private const val HUGGINGFACE_DATETIME_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"
|
||||||
|
|
||||||
@Module
|
@Module
|
||||||
@InstallIn(SingletonComponent::class)
|
@InstallIn(SingletonComponent::class)
|
||||||
internal abstract class AppModule {
|
internal abstract class AppModule {
|
||||||
|
|
@ -61,7 +64,7 @@ internal abstract class AppModule {
|
||||||
): HuggingFaceRemoteDataSource
|
): HuggingFaceRemoteDataSource
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private const val USE_STUB_ENGINE = false
|
const val USE_STUB_ENGINE = false
|
||||||
|
|
||||||
@Provides
|
@Provides
|
||||||
fun provideInferenceEngine(@ApplicationContext context: Context): InferenceEngine {
|
fun provideInferenceEngine(@ApplicationContext context: Context): InferenceEngine {
|
||||||
|
|
@ -111,7 +114,9 @@ internal abstract class AppModule {
|
||||||
Retrofit.Builder()
|
Retrofit.Builder()
|
||||||
.baseUrl("https://huggingface.co/")
|
.baseUrl("https://huggingface.co/")
|
||||||
.client(okHttpClient)
|
.client(okHttpClient)
|
||||||
.addConverterFactory(GsonConverterFactory.create())
|
.addConverterFactory(GsonConverterFactory.create(
|
||||||
|
GsonBuilder().setDateFormat(HUGGINGFACE_DATETIME_FORMAT).create()
|
||||||
|
))
|
||||||
.build()
|
.build()
|
||||||
.create(HuggingFaceApiService::class.java)
|
.create(HuggingFaceApiService::class.java)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue