Use common sampler
This commit is contained in:
parent 1f255d4bca
commit d4ab3832cf
@@ -4,6 +4,7 @@
 #include <cmath>
 #include <string>
 #include <unistd.h>
+#include <sampling.h>
 #include "llama.h"
 #include "common.h"
 
@@ -12,6 +13,8 @@
 #define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
 
 constexpr int CONTEXT_SIZE = 2048;
+constexpr float SAMPLER_TEMP = 0.3f;
+
 constexpr int N_THREADS_MIN = 1;
 constexpr int N_THREADS_MAX = 8;
 constexpr int N_THREADS_HEADROOM = 2;
@@ -316,11 +319,11 @@ Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_p
 
 extern "C"
 JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
-    auto sparams = llama_sampler_chain_default_params();
-    sparams.no_perf = true;
-    llama_sampler *smpl = llama_sampler_chain_init(sparams);
-    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject, jlong model_pointer) {
+    auto *const model = reinterpret_cast<llama_model *>(model_pointer);
+    common_params_sampling sparams;
+    sparams.temp = SAMPLER_TEMP;
+    common_sampler *smpl = common_sampler_init(model, sparams);
 
     return reinterpret_cast<jlong>(smpl);
 }
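For context, a minimal sketch of the common-sampler lifecycle this JNI function now wraps, assuming llama.cpp's common/sampling.h API as used in the hunk above (common_sampler_free is assumed from that header; model and context setup elided):

// Sketch only: lifecycle of a common_sampler, per common/sampling.h.
#include "sampling.h"

void sampler_sketch(llama_model *model, llama_context *ctx) {
    common_params_sampling sparams;     // defaults for temp/top-k/top-p/penalties
    sparams.temp = 0.3f;                // mirrors SAMPLER_TEMP above

    common_sampler *smpl = common_sampler_init(model, sparams);

    // one decode step: sample from the last logits, then record the choice
    const llama_token tok = common_sampler_sample(smpl, ctx, /* idx = */ -1);
    common_sampler_accept(smpl, tok, /* accept_grammar = */ true);

    common_sampler_free(smpl);          // assumed counterpart for teardown
}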
@@ -366,9 +369,9 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     const auto tokens_list = common_tokenize(context, text, true, parse_special);
 
     auto n_ctx = llama_n_ctx(context);
-    auto n_kv_req = tokens_list.size() + n_len;
-
-    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
+    auto n_tokens = tokens_list.size();
+    auto n_kv_req = n_tokens + n_len;
+    LOGi("n_ctx = %d, n_kv_req = n_tokens (%d) + n_len (%d) = %d", n_ctx, n_tokens, n_len, n_kv_req);
 
     if (n_kv_req > n_ctx) {
         LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
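To make the logged arithmetic concrete (illustrative numbers, not from the commit): with n_ctx = 2048 (CONTEXT_SIZE above), a prompt that tokenizes to n_tokens = 1900 and a requested n_len = 256 give n_kv_req = 1900 + 256 = 2156, which exceeds n_ctx, so the guard above fires and completion_init logs the KV-cache error instead of decoding.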
@@ -410,7 +413,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 ) {
     auto *const context = reinterpret_cast<llama_context *>(context_pointer);
     auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    auto *const sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
+    auto *const sampler = reinterpret_cast<common_sampler *>(sampler_pointer);
     const auto *const model = llama_get_model(context);
     const auto *const vocab = llama_model_get_vocab(model);
 
@@ -419,7 +422,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
 
     // sample the most likely token
-    const auto new_token_id = llama_sampler_sample(sampler, context, -1);
+    const auto new_token_id = common_sampler_sample(sampler, context, -1);
+    common_sampler_accept(sampler, new_token_id, /* accept_grammar= */ true);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
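Taken together, the loop body now follows the sample-then-accept pattern; a minimal sketch of one iteration, assuming the same common/sampling.h API (JNI plumbing and batch decoding elided):

// Sketch only: one completion_loop iteration after this change.
llama_token next_token(common_sampler *smpl, llama_context *ctx,
                       const llama_vocab *vocab, bool &done) {
    // sample using the configured chain (temp = 0.3f here, no longer greedy)
    const llama_token tok = common_sampler_sample(smpl, ctx, /* idx = */ -1);
    // feed the token back so repetition/grammar state stays consistent
    common_sampler_accept(smpl, tok, /* accept_grammar = */ true);

    done = llama_vocab_is_eog(vocab, tok);  // stop on end-of-generation
    return tok;
}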
@@ -45,7 +45,7 @@ class LLamaAndroid {
     private external fun backend_free()
     private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
     private external fun free_batch(batch: Long)
-    private external fun new_sampler(): Long
+    private external fun new_sampler(model: Long): Long
     private external fun free_sampler(sampler: Long)
     private external fun bench_model(
         context: Long,
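For reference, the renamed external above resolves to the JNI symbol changed earlier in this commit; JNI name mangling escapes the underscore in new_sampler as _1, and the Kotlin Long carries the native pointer:

    Kotlin: private external fun new_sampler(model: Long): Long
    C++:    Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject, jlong model_pointer)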
@@ -103,7 +103,7 @@ class LLamaAndroid {
                     val batch = new_batch(DEFAULT_BATCH_SIZE, 0, 1)
                     if (batch == 0L) throw IllegalStateException("new_batch() failed")
 
-                    val sampler = new_sampler()
+                    val sampler = new_sampler(model)
                     if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
 
                     Log.i(tag, "Loaded model $pathToModel")
@@ -133,7 +133,7 @@ class LLamaAndroid {
                     emit(str)
                 }
 
-                kv_cache_clear(state.context)
+                // kv_cache_clear(state.context)
             }
             else -> {}
         }