diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceApiService.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceApiService.kt new file mode 100644 index 0000000000..4e1973a0cc --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceApiService.kt @@ -0,0 +1,29 @@ +package com.example.llama.revamp.data.remote + +import okhttp3.ResponseBody +import retrofit2.http.GET +import retrofit2.http.Path +import retrofit2.http.Query +import retrofit2.http.Streaming + +interface HuggingFaceApiService { + @GET("api/models") + suspend fun getModels( + @Query("search") search: String? = null, + @Query("author") author: String? = null, + @Query("filter") filter: String? = null, + @Query("sort") sort: String? = null, + @Query("direction") direction: String? = null, + @Query("limit") limit: Int? = null + ): List + + @GET("api/models/{modelId}") + suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModel + + @GET("{modelId}/resolve/main/{filePath}") + @Streaming + suspend fun downloadModelFile( + @Path("modelId") modelId: String, + @Path("filePath") filePath: String + ): ResponseBody +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceModel.kt new file mode 100644 index 0000000000..95d8e64d89 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceModel.kt @@ -0,0 +1,14 @@ +package com.example.llama.revamp.data.remote + +data class HuggingFaceModel( + val id: String, + val modelId: String, + val likes: Int?, + val trendingScore: Int?, + val private: Boolean?, + val downloads: Int?, + val tags: List?, + val pipeline_tag: String?, + val library_name: String?, + val createdAt: String? +) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceRemoteDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceRemoteDataSource.kt new file mode 100644 index 0000000000..723a5a1bf3 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/remote/HuggingFaceRemoteDataSource.kt @@ -0,0 +1,79 @@ +package com.example.llama.revamp.data.remote + +import android.util.Log +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import java.io.File +import javax.inject.Inject +import javax.inject.Singleton + +interface HuggingFaceRemoteDataSource { + suspend fun searchModels( + query: String? = "gguf", + filter: String? = "text-generation", // Only generative models, + sort: String? = "downloads", + direction: String? = "-1", + limit: Int? = 20 + ): List + + suspend fun getModelDetails(modelId: String): HuggingFaceModel + + suspend fun downloadModelFile(modelId: String, filePath: String, outputFile: File): Result +} + +@Singleton +class HuggingFaceRemoteDataSourceImpl @Inject constructor( + private val apiService: HuggingFaceApiService +) : HuggingFaceRemoteDataSource { + + override suspend fun searchModels( + query: String?, + filter: String?, + sort: String?, + direction: String?, + limit: Int? + ) = withContext(Dispatchers.IO) { + apiService.getModels( + search = query, + filter = filter, + sort = sort, + direction = direction, + limit = limit + ) + } + + override suspend fun getModelDetails( + modelId: String + ) = withContext(Dispatchers.IO) { + apiService.getModelDetails(modelId) + } + + override suspend fun downloadModelFile( + modelId: String, + filePath: String, + outputFile: File + ): Result = withContext(Dispatchers.IO) { + try { + val response = apiService.downloadModelFile(modelId, filePath) + + // Create parent directories if needed + outputFile.parentFile?.mkdirs() + + // Save the file + response.byteStream().use { input -> + outputFile.outputStream().use { output -> + input.copyTo(output) + } + } + + Result.success(outputFile) + } catch (e: Exception) { + Log.e(TAG, "Error downloading file $filePath: ${e.message}") + Result.failure(e) + } + } + + companion object { + private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt index 9f63a42a76..4bada497e9 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt @@ -4,16 +4,19 @@ import android.content.Context import android.llama.cpp.InferenceEngine import android.llama.cpp.LLamaAndroid import com.example.llama.revamp.data.local.AppDatabase +import com.example.llama.revamp.data.remote.HuggingFaceApiService +import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSource +import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSourceImpl import com.example.llama.revamp.data.repository.ModelRepository import com.example.llama.revamp.data.repository.ModelRepositoryImpl import com.example.llama.revamp.data.repository.SystemPromptRepository import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl import com.example.llama.revamp.engine.BenchmarkService import com.example.llama.revamp.engine.ConversationService -import com.example.llama.revamp.engine.StubInferenceEngine import com.example.llama.revamp.engine.InferenceService import com.example.llama.revamp.engine.InferenceServiceImpl import com.example.llama.revamp.engine.ModelLoadingService +import com.example.llama.revamp.engine.StubInferenceEngine import com.example.llama.revamp.monitoring.PerformanceMonitor import dagger.Binds import dagger.Module @@ -21,6 +24,10 @@ import dagger.Provides import dagger.hilt.InstallIn import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.components.SingletonComponent +import okhttp3.OkHttpClient +import okhttp3.logging.HttpLoggingInterceptor +import retrofit2.Retrofit +import retrofit2.converter.gson.GsonConverterFactory import javax.inject.Singleton @Module @@ -45,9 +52,13 @@ internal abstract class AppModule { @Binds abstract fun bindsSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository + @Binds + abstract fun bindHuggingFaceRemoteDataSource( + impl: HuggingFaceRemoteDataSourceImpl + ): HuggingFaceRemoteDataSource + companion object { @Provides - @Singleton fun provideInferenceEngine(): InferenceEngine { val useRealEngine = true return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine() @@ -64,5 +75,22 @@ internal abstract class AppModule { @Provides fun providesSystemPromptDao(appDatabase: AppDatabase) = appDatabase.systemPromptDao() + + @Provides + @Singleton + fun provideOkhttpClient() = OkHttpClient.Builder() + .addInterceptor(HttpLoggingInterceptor().apply { + level = HttpLoggingInterceptor.Level.BODY + }).build() + + @Provides + @Singleton + fun provideHuggingFaceApiService(okHttpClient: OkHttpClient): HuggingFaceApiService = + Retrofit.Builder() + .baseUrl("https://huggingface.co/") + .client(okHttpClient) + .addConverterFactory(GsonConverterFactory.create()) + .build() + .create(HuggingFaceApiService::class.java) } }