Tidy & clean LLamaAndroid binding
This commit is contained in:
parent
8cf6b42d46
commit
1f255d4bca
|
|
@ -1,46 +1,31 @@
|
||||||
#include <android/log.h>
|
#include <android/log.h>
|
||||||
#include <jni.h>
|
#include <jni.h>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <math.h>
|
#include <cmath>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
// Write C++ code here.
|
|
||||||
//
|
|
||||||
// Do not forget to dynamically load the C++ library into your application.
|
|
||||||
//
|
|
||||||
// For instance,
|
|
||||||
//
|
|
||||||
// In MainActivity.java:
|
|
||||||
// static {
|
|
||||||
// System.loadLibrary("llama-android");
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// Or, in MainActivity.kt:
|
|
||||||
// companion object {
|
|
||||||
// init {
|
|
||||||
// System.loadLibrary("llama-android")
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
#define TAG "llama-android.cpp"
|
#define TAG "llama-android.cpp"
|
||||||
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||||
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
|
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
|
||||||
|
|
||||||
|
constexpr int CONTEXT_SIZE = 2048;
|
||||||
|
constexpr int N_THREADS_MIN = 1;
|
||||||
|
constexpr int N_THREADS_MAX = 8;
|
||||||
|
constexpr int N_THREADS_HEADROOM = 2;
|
||||||
|
|
||||||
jclass la_int_var;
|
jclass la_int_var;
|
||||||
jmethodID la_int_var_value;
|
jmethodID la_int_var_value;
|
||||||
jmethodID la_int_var_inc;
|
jmethodID la_int_var_inc;
|
||||||
|
|
||||||
std::string cached_token_chars;
|
std::string cached_token_chars;
|
||||||
|
|
||||||
bool is_valid_utf8(const char * string) {
|
bool is_valid_utf8(const char *string) {
|
||||||
if (!string) {
|
if (!string) { return true; }
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
const unsigned char * bytes = (const unsigned char *)string;
|
const auto *bytes = (const unsigned char *) string;
|
||||||
int num;
|
int num;
|
||||||
|
|
||||||
while (*bytes != 0x00) {
|
while (*bytes != 0x00) {
|
||||||
|
|
@ -72,11 +57,26 @@ bool is_valid_utf8(const char * string) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
|
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
||||||
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
|
int priority;
|
||||||
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
|
switch (level) {
|
||||||
else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
|
case GGML_LOG_LEVEL_ERROR:
|
||||||
else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
|
priority = ANDROID_LOG_ERROR;
|
||||||
|
break;
|
||||||
|
case GGML_LOG_LEVEL_WARN:
|
||||||
|
priority = GGML_LOG_LEVEL_WARN;
|
||||||
|
break;
|
||||||
|
case GGML_LOG_LEVEL_INFO:
|
||||||
|
priority = GGML_LOG_LEVEL_INFO;
|
||||||
|
break;
|
||||||
|
case GGML_LOG_LEVEL_DEBUG:
|
||||||
|
priority = GGML_LOG_LEVEL_DEBUG;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
priority = ANDROID_LOG_DEFAULT;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
__android_log_print(priority, TAG, fmt, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
|
|
@ -116,16 +116,18 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
|
int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
|
||||||
|
(int) sysconf(_SC_NPROCESSORS_ONLN) -
|
||||||
|
N_THREADS_HEADROOM));
|
||||||
LOGi("Using %d threads", n_threads);
|
LOGi("Using %d threads", n_threads);
|
||||||
|
|
||||||
llama_context_params ctx_params = llama_context_default_params();
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
|
|
||||||
ctx_params.n_ctx = 2048;
|
ctx_params.n_ctx = CONTEXT_SIZE;
|
||||||
ctx_params.n_threads = n_threads;
|
ctx_params.n_threads = n_threads;
|
||||||
ctx_params.n_threads_batch = n_threads;
|
ctx_params.n_threads_batch = n_threads;
|
||||||
|
|
||||||
llama_context * context = llama_new_context_with_model(model, ctx_params);
|
llama_context *context = llama_init_from_model(model, ctx_params);
|
||||||
|
|
||||||
if (!context) {
|
if (!context) {
|
||||||
LOGe("llama_new_context_with_model() returned null)");
|
LOGe("llama_new_context_with_model() returned null)");
|
||||||
|
|
@ -152,7 +154,7 @@ Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
|
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
|
||||||
llama_log_set(log_callback, NULL);
|
llama_log_set(log_callback, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
|
|
@ -167,17 +169,17 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||||
jint tg,
|
jint tg,
|
||||||
jint pl,
|
jint pl,
|
||||||
jint nr
|
jint nr
|
||||||
) {
|
) {
|
||||||
auto pp_avg = 0.0;
|
auto pp_avg = 0.0;
|
||||||
auto tg_avg = 0.0;
|
auto tg_avg = 0.0;
|
||||||
auto pp_std = 0.0;
|
auto pp_std = 0.0;
|
||||||
auto tg_std = 0.0;
|
auto tg_std = 0.0;
|
||||||
|
|
||||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
auto *const context = reinterpret_cast<llama_context *>(context_pointer);
|
||||||
const auto model = reinterpret_cast<llama_model *>(model_pointer);
|
auto *const model = reinterpret_cast<llama_model *>(model_pointer);
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
||||||
|
|
||||||
const int n_ctx = llama_n_ctx(context);
|
const uint32_t n_ctx = llama_n_ctx(context);
|
||||||
|
|
||||||
LOGi("n_ctx = %d", n_ctx);
|
LOGi("n_ctx = %d", n_ctx);
|
||||||
|
|
||||||
|
|
@ -190,7 +192,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||||
|
|
||||||
const int n_tokens = pp;
|
const int n_tokens = pp;
|
||||||
for (i = 0; i < n_tokens; i++) {
|
for (i = 0; i < n_tokens; i++) {
|
||||||
common_batch_add(*batch, 0, i, { 0 }, false);
|
common_batch_add(*batch, 0, i, {0}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
batch->logits[batch->n_tokens - 1] = true;
|
batch->logits[batch->n_tokens - 1] = true;
|
||||||
|
|
@ -212,7 +214,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
common_batch_clear(*batch);
|
||||||
for (j = 0; j < pl; j++) {
|
for (j = 0; j < pl; j++) {
|
||||||
common_batch_add(*batch, 0, i, { j }, true);
|
common_batch_add(*batch, 0, i, {j}, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGi("llama_decode() text generation: %d", i);
|
LOGi("llama_decode() text generation: %d", i);
|
||||||
|
|
@ -257,25 +259,27 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||||
const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
|
const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
|
||||||
const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
|
const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
|
||||||
|
|
||||||
const auto backend = "(Android)"; // TODO: What should this be?
|
const auto *const backend = "(Android)"; // TODO: What should this be?
|
||||||
|
|
||||||
std::stringstream result;
|
std::stringstream result;
|
||||||
result << std::setprecision(2);
|
result << std::setprecision(2);
|
||||||
result << "| model | size | params | backend | test | t/s |\n";
|
result << "| model | size | params | backend | test | t/s |\n";
|
||||||
result << "| --- | --- | --- | --- | --- | --- |\n";
|
result << "| --- | --- | --- | --- | --- | --- |\n";
|
||||||
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
|
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
|
||||||
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
|
<< backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
|
||||||
|
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
|
||||||
|
<< backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
|
||||||
|
|
||||||
return env->NewStringUTF(result.str().c_str());
|
return env->NewStringUTF(result.str().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jlong JNICALL
|
JNIEXPORT jlong JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
|
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd,
|
||||||
|
jint n_seq_max) {
|
||||||
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
|
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
|
||||||
|
|
||||||
llama_batch *batch = new llama_batch {
|
auto *batch = new llama_batch{
|
||||||
0,
|
0,
|
||||||
nullptr,
|
nullptr,
|
||||||
nullptr,
|
nullptr,
|
||||||
|
|
@ -305,8 +309,8 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
|
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
|
||||||
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
|
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); // TODO: what is this?
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
||||||
delete batch;
|
delete batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -315,7 +319,7 @@ JNIEXPORT jlong JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
|
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
|
||||||
auto sparams = llama_sampler_chain_default_params();
|
auto sparams = llama_sampler_chain_default_params();
|
||||||
sparams.no_perf = true;
|
sparams.no_perf = true;
|
||||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
llama_sampler *smpl = llama_sampler_chain_init(sparams);
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
|
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
|
||||||
|
|
||||||
return reinterpret_cast<jlong>(smpl);
|
return reinterpret_cast<jlong>(smpl);
|
||||||
|
|
@ -324,7 +328,9 @@ Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
|
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
|
||||||
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
|
// Properly cast from jlong to pointer type
|
||||||
|
auto* sampler = (llama_sampler*)(void*)(sampler_pointer);
|
||||||
|
llama_sampler_free(sampler);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
|
|
@ -349,13 +355,12 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
||||||
jstring jtext,
|
jstring jtext,
|
||||||
jboolean format_chat,
|
jboolean format_chat,
|
||||||
jint n_len
|
jint n_len
|
||||||
) {
|
) {
|
||||||
|
|
||||||
cached_token_chars.clear();
|
cached_token_chars.clear();
|
||||||
|
|
||||||
const auto text = env->GetStringUTFChars(jtext, 0);
|
const auto *const text = env->GetStringUTFChars(jtext, 0);
|
||||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
auto *const context = reinterpret_cast<llama_context *>(context_pointer);
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
||||||
|
|
||||||
bool parse_special = (format_chat == JNI_TRUE);
|
bool parse_special = (format_chat == JNI_TRUE);
|
||||||
const auto tokens_list = common_tokenize(context, text, true, parse_special);
|
const auto tokens_list = common_tokenize(context, text, true, parse_special);
|
||||||
|
|
@ -369,7 +374,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
||||||
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
|
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto id : tokens_list) {
|
for (auto id: tokens_list) {
|
||||||
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
|
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -377,7 +382,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
||||||
|
|
||||||
// evaluate the initial prompt
|
// evaluate the initial prompt
|
||||||
for (auto i = 0; i < tokens_list.size(); i++) {
|
for (auto i = 0; i < tokens_list.size(); i++) {
|
||||||
common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
|
common_batch_add(*batch, tokens_list[i], i, {0}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// llama_decode will output logits only for the last token of the prompt
|
// llama_decode will output logits only for the last token of the prompt
|
||||||
|
|
@ -395,7 +400,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jstring JNICALL
|
JNIEXPORT jstring JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
||||||
JNIEnv * env,
|
JNIEnv *env,
|
||||||
jobject,
|
jobject,
|
||||||
jlong context_pointer,
|
jlong context_pointer,
|
||||||
jlong batch_pointer,
|
jlong batch_pointer,
|
||||||
|
|
@ -403,11 +408,11 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
||||||
jint n_len,
|
jint n_len,
|
||||||
jobject intvar_ncur
|
jobject intvar_ncur
|
||||||
) {
|
) {
|
||||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
auto *const context = reinterpret_cast<llama_context *>(context_pointer);
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
||||||
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
|
auto *const sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
|
||||||
const auto model = llama_get_model(context);
|
const auto *const model = llama_get_model(context);
|
||||||
const auto vocab = llama_model_get_vocab(model);
|
const auto *const vocab = llama_model_get_vocab(model);
|
||||||
|
|
||||||
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
|
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
|
||||||
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
|
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
|
||||||
|
|
@ -427,14 +432,15 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
||||||
jstring new_token = nullptr;
|
jstring new_token = nullptr;
|
||||||
if (is_valid_utf8(cached_token_chars.c_str())) {
|
if (is_valid_utf8(cached_token_chars.c_str())) {
|
||||||
new_token = env->NewStringUTF(cached_token_chars.c_str());
|
new_token = env->NewStringUTF(cached_token_chars.c_str());
|
||||||
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
|
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(),
|
||||||
|
new_token_chars.c_str(), new_token_id);
|
||||||
cached_token_chars.clear();
|
cached_token_chars.clear();
|
||||||
} else {
|
} else {
|
||||||
new_token = env->NewStringUTF("");
|
new_token = env->NewStringUTF("");
|
||||||
}
|
}
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
common_batch_clear(*batch);
|
||||||
common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
|
common_batch_add(*batch, new_token_id, n_cur, {0}, true);
|
||||||
|
|
||||||
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
|
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,6 @@ class LLamaAndroid {
|
||||||
}
|
}
|
||||||
}.asCoroutineDispatcher()
|
}.asCoroutineDispatcher()
|
||||||
|
|
||||||
private val nlen: Int = 64
|
|
||||||
|
|
||||||
private external fun log_to_android()
|
private external fun log_to_android()
|
||||||
private external fun load_model(filename: String): Long
|
private external fun load_model(filename: String): Long
|
||||||
private external fun free_model(model: Long)
|
private external fun free_model(model: Long)
|
||||||
|
|
@ -102,7 +100,7 @@ class LLamaAndroid {
|
||||||
val context = new_context(model)
|
val context = new_context(model)
|
||||||
if (context == 0L) throw IllegalStateException("new_context() failed")
|
if (context == 0L) throw IllegalStateException("new_context() failed")
|
||||||
|
|
||||||
val batch = new_batch(512, 0, 1)
|
val batch = new_batch(DEFAULT_BATCH_SIZE, 0, 1)
|
||||||
if (batch == 0L) throw IllegalStateException("new_batch() failed")
|
if (batch == 0L) throw IllegalStateException("new_batch() failed")
|
||||||
|
|
||||||
val sampler = new_sampler()
|
val sampler = new_sampler()
|
||||||
|
|
@ -116,17 +114,25 @@ class LLamaAndroid {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
|
fun send(
|
||||||
|
message: String,
|
||||||
|
formatChat: Boolean = false,
|
||||||
|
predictLength: Int = DEFAULT_PREDICT_LENGTH,
|
||||||
|
): Flow<String> = flow {
|
||||||
when (val state = threadLocalState.get()) {
|
when (val state = threadLocalState.get()) {
|
||||||
is State.Loaded -> {
|
is State.Loaded -> {
|
||||||
val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
|
val nCur = IntVar(
|
||||||
while (ncur.value <= nlen) {
|
completion_init(state.context, state.batch, message, formatChat, predictLength)
|
||||||
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
|
)
|
||||||
if (str == null) {
|
|
||||||
break
|
while (nCur.value <= predictLength) {
|
||||||
}
|
val str = completion_loop(
|
||||||
|
state.context, state.batch, state.sampler, predictLength, nCur
|
||||||
|
) ?: break
|
||||||
|
|
||||||
emit(str)
|
emit(str)
|
||||||
}
|
}
|
||||||
|
|
||||||
kv_cache_clear(state.context)
|
kv_cache_clear(state.context)
|
||||||
}
|
}
|
||||||
else -> {}
|
else -> {}
|
||||||
|
|
@ -155,6 +161,9 @@ class LLamaAndroid {
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
private const val DEFAULT_BATCH_SIZE = 512
|
||||||
|
private const val DEFAULT_PREDICT_LENGTH = 128
|
||||||
|
|
||||||
private class IntVar(value: Int) {
|
private class IntVar(value: Int) {
|
||||||
@Volatile
|
@Volatile
|
||||||
var value: Int = value
|
var value: Int = value
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue