diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index c1c8fa380f..458fdbdc60 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "llama.h" #include "common.h" @@ -12,6 +13,8 @@ #define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) constexpr int CONTEXT_SIZE = 2048; +constexpr float SAMPLER_TEMP = 0.3f; + constexpr int N_THREADS_MIN = 1; constexpr int N_THREADS_MAX = 8; constexpr int N_THREADS_HEADROOM = 2; @@ -316,11 +319,11 @@ Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_p extern "C" JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) { - auto sparams = llama_sampler_chain_default_params(); - sparams.no_perf = true; - llama_sampler *smpl = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); +Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject, jlong model_pointer) { + auto *const model = reinterpret_cast<llama_model *>(model_pointer); + common_params_sampling sparams; + sparams.temp = SAMPLER_TEMP; + common_sampler *smpl = common_sampler_init(model, sparams); return reinterpret_cast<jlong>(smpl); } @@ -366,9 +369,9 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( const auto tokens_list = common_tokenize(context, text, true, parse_special); auto n_ctx = llama_n_ctx(context); - auto n_kv_req = tokens_list.size() + n_len; - - LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req); + auto n_tokens = tokens_list.size(); + auto n_kv_req = n_tokens + n_len; + LOGi("n_ctx = %d, n_kv_req = n_tokens (%zu) + n_len (%d) = %zu", n_ctx, n_tokens, n_len, n_kv_req); if (n_kv_req > n_ctx) { LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough"); @@ -410,7 +413,7 @@ 
Java_android_llama_cpp_LLamaAndroid_completion_1loop( ) { auto *const context = reinterpret_cast<llama_context *>(context_pointer); auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer); - auto *const sampler = reinterpret_cast<llama_sampler *>(sampler_pointer); + auto *const sampler = reinterpret_cast<common_sampler *>(sampler_pointer); const auto *const model = llama_get_model(context); const auto *const vocab = llama_model_get_vocab(model); @@ -419,7 +422,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); // sample the most likely token - const auto new_token_id = llama_sampler_sample(sampler, context, -1); + const auto new_token_id = common_sampler_sample(sampler, context, -1); + common_sampler_accept(sampler, new_token_id, /* accept_grammar= */ true); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt index 95a8177426..65eedd2d41 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt @@ -45,7 +45,7 @@ class LLamaAndroid { private external fun backend_free() private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long private external fun free_batch(batch: Long) - private external fun new_sampler(): Long + private external fun new_sampler(model: Long): Long private external fun free_sampler(sampler: Long) private external fun bench_model( context: Long, @@ -103,7 +103,7 @@ class LLamaAndroid { val batch = new_batch(DEFAULT_BATCH_SIZE, 0, 1) if (batch == 0L) throw IllegalStateException("new_batch() failed") - val sampler = new_sampler(model) if (sampler == 0L) throw IllegalStateException("new_sampler() failed") 
Log.i(tag, "Loaded model $pathToModel") @@ -133,7 +133,7 @@ class LLamaAndroid { emit(str) } - kv_cache_clear(state.context) + // kv_cache_clear(state.context) // NOTE(review): KV cache now persists across send() calls — if intentional, delete this dead line instead of commenting it out; confirm free_sampler() was also updated to call common_sampler_free } else -> {} }