Use common sampler

Han Yin 2025-03-26 13:42:24 -07:00
parent 1f255d4bca
commit d4ab3832cf
2 changed files with 17 additions and 13 deletions

llama-android.cpp

@@ -4,6 +4,7 @@
 #include <cmath>
 #include <string>
 #include <unistd.h>
+#include <sampling.h>
 #include "llama.h"
 #include "common.h"
@@ -12,6 +13,8 @@
 #define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
 constexpr int CONTEXT_SIZE = 2048;
+constexpr float SAMPLER_TEMP = 0.3f;
 constexpr int N_THREADS_MIN = 1;
 constexpr int N_THREADS_MAX = 8;
 constexpr int N_THREADS_HEADROOM = 2;
@@ -316,11 +319,11 @@ Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_p
 extern "C"
 JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
-    auto sparams = llama_sampler_chain_default_params();
-    sparams.no_perf = true;
-    llama_sampler *smpl = llama_sampler_chain_init(sparams);
-    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject, jlong model_pointer) {
+    auto *const model = reinterpret_cast<llama_model *>(model_pointer);
+    common_params_sampling sparams;
+    sparams.temp = SAMPLER_TEMP;
+    common_sampler *smpl = common_sampler_init(model, sparams);
     return reinterpret_cast<jlong>(smpl);
 }
@@ -366,9 +369,9 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     const auto tokens_list = common_tokenize(context, text, true, parse_special);
     auto n_ctx = llama_n_ctx(context);
-    auto n_kv_req = tokens_list.size() + n_len;
-    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
+    auto n_tokens = tokens_list.size();
+    auto n_kv_req = n_tokens + n_len;
+    LOGi("n_ctx = %d, n_kv_req = n_tokens (%d) + n_len (%d) = %d", n_ctx, n_tokens, n_len, n_kv_req);
     if (n_kv_req > n_ctx) {
         LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
@@ -410,7 +413,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 ) {
     auto *const context = reinterpret_cast<llama_context *>(context_pointer);
     auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    auto *const sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
+    auto *const sampler = reinterpret_cast<common_sampler *>(sampler_pointer);
     const auto *const model = llama_get_model(context);
     const auto *const vocab = llama_model_get_vocab(model);
@@ -419,7 +422,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
     // sample the most likely token
-    const auto new_token_id = llama_sampler_sample(sampler, context, -1);
+    const auto new_token_id = common_sampler_sample(sampler, context, -1);
+    common_sampler_accept(sampler, new_token_id, /* accept_grammar= */ true);
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
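
Taken together, the C++ changes drop the hand-assembled llama_sampler chain (greedy decoding via llama_sampler_chain_init) in favor of the common_sampler helper from llama.cpp's common library, pinned at temperature 0.3. Below is a minimal sketch of the lifecycle these calls imply, assuming an already-loaded model and context; the helper name sample_one is illustrative and not part of this commit:

#include <sampling.h>  // common_sampler_* helpers (llama.cpp "common" library)
#include "llama.h"

// One decode step with the common sampler, mirroring this commit's calls.
static llama_token sample_one(llama_model * model, llama_context * ctx) {
    common_params_sampling sparams;  // default top-k/top-p/penalty settings
    sparams.temp = 0.3f;             // the temperature this commit pins

    common_sampler * smpl = common_sampler_init(model, sparams);

    // Sample from the logits of the last decoded position (idx = -1), then
    // accept the token so repetition penalties and grammar state advance.
    const llama_token id = common_sampler_sample(smpl, ctx, -1);
    common_sampler_accept(smpl, id, /* accept_grammar= */ true);

    common_sampler_free(smpl);
    return id;
}

In the commit itself the sampler is built once in new_sampler() and reused across completion_loop() calls, which is why common_sampler_accept now runs on every step: the common sampler carries state (penalties, grammar) that has to track each emitted token.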

LLamaAndroid.kt

@@ -45,7 +45,7 @@ class LLamaAndroid {
     private external fun backend_free()
     private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
     private external fun free_batch(batch: Long)
-    private external fun new_sampler(): Long
+    private external fun new_sampler(model: Long): Long
     private external fun free_sampler(sampler: Long)
     private external fun bench_model(
         context: Long,
@@ -103,7 +103,7 @@ class LLamaAndroid {
     val batch = new_batch(DEFAULT_BATCH_SIZE, 0, 1)
     if (batch == 0L) throw IllegalStateException("new_batch() failed")
-    val sampler = new_sampler()
+    val sampler = new_sampler(model)
     if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
     Log.i(tag, "Loaded model $pathToModel")
@@ -133,7 +133,7 @@ class LLamaAndroid {
             emit(str)
         }
-        kv_cache_clear(state.context)
+        // kv_cache_clear(state.context)
     }
     else -> {}
 }
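
Not shown in this diff is the native side of free_sampler(). Since the jlong handle now points at a common_sampler rather than a llama_sampler, its implementation presumably has to release it with common_sampler_free instead of llama_sampler_free; a sketch under that assumption:

extern "C" JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
    // Assumption: the handle created by new_sampler() above is a common_sampler*.
    common_sampler_free(reinterpret_cast<common_sampler *>(sampler_pointer));
}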