Polish: better logging & documentation
commit 2b52563737
parent ec502cfde9
@@ -37,10 +37,10 @@ constexpr int N_THREADS_MIN = 1;
 constexpr int N_THREADS_MAX = 8;
 constexpr int N_THREADS_HEADROOM = 2;
 
-constexpr int CONTEXT_SIZE = 4096;
+constexpr int DEFAULT_CONTEXT_SIZE = 8192;
 constexpr int OVERFLOW_HEADROOM = 4;
 constexpr int BATCH_SIZE = 512;
-constexpr float SAMPLER_TEMP = 0.3f;
+constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
 
 static llama_model * g_model;
 static llama_context * g_context;
@@ -98,12 +98,11 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
     llama_model_params model_params = llama_model_default_params();
 
     const auto *path_to_model = env->GetStringUTFChars(filename, 0);
-    LOGd("Loading model from: %s", path_to_model);
+    LOGd("%s: Loading model from: \n%s\n", __func__, path_to_model);
 
     auto *model = llama_model_load_from_file(path_to_model, model_params);
     env->ReleaseStringUTFChars(filename, path_to_model);
     if (!model) {
-        LOGe("load_model() failed");
         return 1;
     }
     g_model = model;
@@ -112,31 +111,36 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
 
 static llama_context *init_context(llama_model *model) {
     if (!model) {
-        LOGe("init_context(): model cannot be null");
+        LOGe("%s: model cannot be null", __func__);
         return nullptr;
     }
 
     // Multi-threading setup
-    int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
+    const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
             (int) sysconf(_SC_NPROCESSORS_ONLN) -
             N_THREADS_HEADROOM));
-    LOGi("Using %d threads", n_threads);
+    LOGi("%s: Using %d threads", __func__, n_threads);
 
     // Context parameters setup
     llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.n_ctx = CONTEXT_SIZE;
+    const int trained_context_size = llama_model_n_ctx_train(model);
+    if (DEFAULT_CONTEXT_SIZE > trained_context_size) {
+        LOGe("%s: Model was trained with only %d context size! Enforcing %d context size...",
+             __func__, trained_context_size, DEFAULT_CONTEXT_SIZE);
+    }
+    ctx_params.n_ctx = DEFAULT_CONTEXT_SIZE;
     ctx_params.n_batch = BATCH_SIZE;
     ctx_params.n_ubatch = BATCH_SIZE;
     ctx_params.n_threads = n_threads;
     ctx_params.n_threads_batch = n_threads;
     auto *context = llama_init_from_model(g_model, ctx_params);
     if (context == nullptr) {
-        LOGe("llama_new_context_with_model() returned null)");
+        LOGe("%s: llama_new_context_with_model() returned null)", __func__);
     }
     return context;
 }
 
-static llama_batch *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
+static llama_batch *new_batch(int n_tokens, int n_seq_max = 1) {
     // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
     auto *batch = new llama_batch{
         0,
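For illustration, here is the thread-count clamp from init_context() as a minimal standalone sketch (not part of the commit); the helper name pick_n_threads is hypothetical and a POSIX sysconf() is assumed:

    #include <algorithm>
    #include <unistd.h>

    constexpr int N_THREADS_MIN      = 1;
    constexpr int N_THREADS_MAX      = 8;
    constexpr int N_THREADS_HEADROOM = 2;

    // Clamp "online cores minus headroom" into [N_THREADS_MIN, N_THREADS_MAX].
    static int pick_n_threads() {
        const int n_online = (int) sysconf(_SC_NPROCESSORS_ONLN);
        // e.g. 8 cores -> 6 threads, 4 cores -> 2, 16 cores -> capped at 8, 1 core -> floored at 1
        return std::max(N_THREADS_MIN, std::min(N_THREADS_MAX, n_online - N_THREADS_HEADROOM));
    }

Note that the new trained-context check only warns when DEFAULT_CONTEXT_SIZE (8192) exceeds llama_model_n_ctx_train(model); n_ctx is still set to 8192 afterwards.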
@@ -148,12 +152,7 @@ static llama_batch *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1
         nullptr,
     };
 
-    if (embd) {
-        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
-    } else {
     batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
-    }
-
     batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
     batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
     batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
@@ -177,7 +176,7 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
     if (!context) { return 1; }
     g_context = context;
     g_batch = new_batch(BATCH_SIZE);
-    g_sampler = new_sampler(SAMPLER_TEMP);
+    g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
     g_chat_templates = common_chat_templates_init(g_model, "");
     return 0;
 }
@@ -305,7 +304,9 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
 
 
 /**
- * Prediction loop's long-term states
+ * Completion loop's long-term states:
+ * - chat management
+ * - position tracking
  */
 constexpr const char *ROLE_SYSTEM = "system";
 constexpr const char *ROLE_USER = "user";
@@ -325,6 +326,8 @@ static void reset_long_term_states(const bool clear_kv_cache = true) {
 }
 
 /**
+ * TODO-hyin: implement sliding-window version as a better alternative
+ *
  * Context shifting by discarding the older half of the tokens appended after system prompt:
  * - take the [system_prompt_position] first tokens from the original prompt
  * - take half of the last (system_prompt_position - system_prompt_position) tokens
@@ -332,12 +335,11 @@ static void reset_long_term_states(const bool clear_kv_cache = true) {
  */
 static void shift_context() {
     const int n_discard = (current_position - system_prompt_position) / 2;
-    LOGi("Discarding %d tokens", n_discard);
-
+    LOGi("%s: Discarding %d tokens", __func__, n_discard);
     llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
     llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
     current_position -= n_discard;
-    LOGi("Context shifting done! Current position: %d", current_position);
+    LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
 }
 
 static std::string chat_add_and_format(const std::string &role, const std::string &content) {
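As a worked example of shift_context() with hypothetical numbers, say system_prompt_position = 100 and current_position = 4000:

    n_discard = (4000 - 100) / 2 = 1950
    llama_memory_seq_rm(...)  drops the KV cells for positions [100, 2050)
    llama_memory_seq_add(...) shifts positions [2050, 4000) left by 1950
    current_position = 4000 - 1950 = 2050

The system prompt (positions [0, 100)) is always kept, and the newer half of the conversation remains in the context.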
@@ -347,10 +349,26 @@ static std::string chat_add_and_format(const std::string &role, const std::strin
     auto formatted = common_chat_format_single(
             g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
     chat_msgs.push_back(new_msg);
-    LOGi("Formatted and added %s message: \n%s\n", role.c_str(), formatted.c_str());
+    LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
     return formatted;
 }
 
+/**
+ * Completion loop's short-term states:
+ * - stop generation position
+ * - token chars caching
+ * - current assistant message being generated
+ */
+static llama_pos stop_generation_position;
+static std::string cached_token_chars;
+static std::ostringstream assistant_ss;
+
+static void reset_short_term_states() {
+    stop_generation_position = 0;
+    cached_token_chars.clear();
+    assistant_ss.str("");
+}
+
 static int decode_tokens_in_batches(
         llama_context *context,
         const llama_tokens &tokens,
@@ -358,15 +376,15 @@ static int decode_tokens_in_batches(
         bool compute_last_logit = false,
         llama_batch *batch = g_batch) {
     // Process tokens in batches using the global batch
-    LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
+    LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
         const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
         common_batch_clear(*batch);
-        LOGv("Preparing a batch size of %d starting at: %d", cur_batch_size, i);
+        LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
 
         // Shift context if current batch cannot fit into the context
-        if (start_pos + i + cur_batch_size >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
-            LOGw("Current batch won't fit into context! Shifting...");
+        if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+            LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
             shift_context();
         }
 
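For a sense of the batching arithmetic (hypothetical numbers): with BATCH_SIZE = 512 and tokens.size() = 1300, the loop runs three times:

    i = 0    -> cur_batch_size = 512
    i = 512  -> cur_batch_size = 512
    i = 1024 -> cur_batch_size = 276

Each chunk is filled into the shared batch and decoded; the context is shifted first whenever start_pos + i + cur_batch_size would reach DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM (8192 - 4 = 8188).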
@@ -381,26 +399,13 @@ static int decode_tokens_in_batches(
         // Decode this batch
         const int decode_result = llama_decode(context, *batch);
         if (decode_result) {
-            LOGe("llama_decode failed w/ %d", decode_result);
+            LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
             return 1;
         }
     }
     return 0;
 }
 
-/**
- * Prediction loop's short-term states
- */
-static llama_pos stop_completion_position;
-static std::string cached_token_chars;
-static std::ostringstream assistant_ss; // For storing current assistant message
-
-static void reset_short_term_states() {
-    stop_completion_position = 0;
-    cached_token_chars.clear();
-    assistant_ss.str("");
-}
-
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
@@ -414,7 +419,7 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
 
     // Obtain system prompt from JEnv
    const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
-    LOGd("System prompt received: \n%s", system_prompt);
+    LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
     std::string formatted_system_prompt(system_prompt);
     env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
 
@@ -432,16 +437,16 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     }
 
     // Handle context overflow
-    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
     if ((int) system_tokens.size() > max_batch_size) {
-        LOGe("System prompt too long for context! %d tokens, max: %d",
-             (int) system_tokens.size(), max_batch_size);
+        LOGe("%s: System prompt too long for context! %d tokens, max: %d",
+             __func__, (int) system_tokens.size(), max_batch_size);
         return 1;
     }
 
     // Decode system tokens in batches
     if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
-        LOGe("llama_decode() failed!");
+        LOGe("%s: llama_decode() failed!", __func__);
         return 2;
     }
 
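With the constants above, max_batch_size works out to 8192 - 4 = 8188, so a system prompt longer than 8188 tokens is rejected outright (return code 1) rather than truncated or shifted.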
@@ -463,7 +468,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
 
     // Obtain and tokenize user prompt
     const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
-    LOGd("User prompt received: \n%s", user_prompt);
+    LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
     std::string formatted_user_prompt(user_prompt);
     env->ReleaseStringUTFChars(juser_prompt, user_prompt);
 
@@ -481,22 +486,22 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
 
     // Ensure user prompt doesn't exceed the context size by truncating if necessary.
     const int user_prompt_size = (int) user_tokens.size();
-    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
     if (user_prompt_size > max_batch_size) {
         const int skipped_tokens = user_prompt_size - max_batch_size;
         user_tokens.resize(max_batch_size);
-        LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
+        LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
     }
 
     // Decode user tokens in batches
     if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
-        LOGe("llama_decode() failed!");
+        LOGe("%s: llama_decode() failed!", __func__);
         return 2;
     }
 
     // Update position
     current_position += user_prompt_size;
-    stop_completion_position = current_position + user_prompt_size + n_predict;
+    stop_generation_position = current_position + user_prompt_size + n_predict;
     return 0;
 }
 
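Worked example with hypothetical sizes: a 9000-token user prompt is truncated to 8188 tokens, skipping 812 tokens with a warning. Separately, for the stop marker: if current_position is 150 after the += update, user_prompt_size is 30 and n_predict is 256, then stop_generation_position = 150 + 30 + 256 = 436.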
@@ -536,19 +541,19 @@ static bool is_valid_utf8(const char *string) {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_completionLoop(
+Java_android_llama_cpp_LLamaAndroid_generateNextToken(
         JNIEnv *env,
         jobject /*unused*/
 ) {
     // Infinite text generation via context shifting
-    if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
-        LOGw("Context full! Shifting...");
+    if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+        LOGw("%s: Context full! Shifting...", __func__);
         shift_context();
     }
 
     // Stop if reaching the marked position
-    if (current_position >= stop_completion_position) {
-        LOGw("STOP: hitting stop position: %d", stop_completion_position);
+    if (current_position >= stop_generation_position) {
+        LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
         return nullptr;
     }
 
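Continuing the hypothetical stop marker above: once current_position reaches 436, generateNextToken() logs the STOP warning and returns nullptr; on the Kotlin side (see the loop near the end of this diff) that null return exits the while (true) loop via ?: break.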
@@ -560,7 +565,7 @@ Java_android_llama_cpp_LLamaAndroid_completionLoop(
     common_batch_clear(*g_batch);
     common_batch_add(*g_batch, new_token_id, current_position, {0}, true);
     if (llama_decode(g_context, *g_batch) != 0) {
-        LOGe("llama_decode() failed for generated token");
+        LOGe("%s: llama_decode() failed for generated token", __func__);
         return nullptr;
     }
 
@@ -578,7 +583,7 @@ Java_android_llama_cpp_LLamaAndroid_completionLoop(
     auto new_token_chars = common_token_to_piece(g_context, new_token_id);
     cached_token_chars += new_token_chars;
 
-    // Create and return Java string
+    // Create and return a valid UTF-8 Java string
     jstring result = nullptr;
     if (is_valid_utf8(cached_token_chars.c_str())) {
         result = env->NewStringUTF(cached_token_chars.c_str());
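The cache exists because one Unicode character can span several token pieces. For example (hypothetical token boundaries), '€' is the UTF-8 sequence 0xE2 0x82 0xAC; if a piece ends after 0xE2 0x82, is_valid_utf8() fails and the bytes stay in cached_token_chars, and only when the next piece completes the character is a Java string created (the cache is presumably cleared after a successful emit, outside this hunk).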
@@ -25,7 +25,7 @@ class LLamaAndroid {
 
     private external fun processSystemPrompt(systemPrompt: String): Int
     private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
-    private external fun completionLoop(): String?
+    private external fun generateNextToken(): String?
 
     /**
      * Thread local state
@@ -128,7 +128,7 @@ class LLamaAndroid {
 
             Log.i(TAG, "User prompt processed! Generating assistant prompt...")
             while (true) {
-                completionLoop()?.let { utf8token ->
+                generateNextToken()?.let { utf8token ->
                     if (utf8token.isNotEmpty()) emit(utf8token)
                 } ?: break
             }