diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9582164b58..62273b2c37 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -209,6 +209,7 @@ llama_build_and_test( peg-parser/tests.h ) llama_build_and_test(test-regex-partial.cpp) +llama_build_and_test(test-save-load-state.cpp) if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf") diff --git a/tests/test-save-load-state.cpp b/tests/test-save-load-state.cpp new file mode 100644 index 0000000000..8a1b849088 --- /dev/null +++ b/tests/test-save-load-state.cpp @@ -0,0 +1,327 @@ +#include "common.h" +#include "log.h" +#include "ggml-backend.h" +#include "ggml.h" +#include "gguf.h" +#include "ggml-cpp.h" +#include "llama.h" +#include "llama-cpp.h" +#include "../src/llama-arch.h" +#include "../src/llama-model-saver.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Taken from test-llama-arch.cpp +static void set_tensor_data(struct ggml_tensor * tensor, void * userdata) { + std::hash hasher; + std::mt19937 gen(hasher(tensor->name) + *(const size_t *) userdata); + std::normal_distribution dis(0.0f, 1.0e-2f); + + const int64_t ne = ggml_nelements(tensor); + if (tensor->type == GGML_TYPE_F32) { + std::vector tmp(ne); + for (int64_t i = 0; i < ne; i++) { + tmp[i] = dis(gen); + } + ggml_backend_tensor_set(tensor, tmp.data(), 0, ggml_nbytes(tensor)); + } else if (tensor->type == GGML_TYPE_F16) { + std::vector tmp(ne); + for (int64_t i = 0; i < ne; i++) { + tmp[i] = ggml_fp32_to_fp16(dis(gen)); + } + ggml_backend_tensor_set(tensor, tmp.data(), 0, ggml_nbytes(tensor)); + } else { + GGML_ABORT("fatal error"); + } +} + +// Taken from test-llama-arch.cpp +static std::vector get_tokens(const uint32_t n_tokens, const uint32_t n_vocab, const size_t seed){ + std::mt19937 gen(seed); + std::uniform_int_distribution<> dis(0, n_vocab - 1); + std::vector ret; + ret.reserve(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + ret.push_back(dis(gen)); + } + return ret; +} + +// Taken from test-llama-arch.cpp +static std::pair get_model_and_ctx( + struct gguf_context * gguf_ctx, const size_t seed) { + llama_model_params model_params = llama_model_default_params(); + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = 0; // will be set from model + ctx_params.n_threads = 4; + ctx_params.n_threads_batch = 4; + + size_t tmp = seed; + llama_model_ptr model(llama_model_init_from_user(gguf_ctx, set_tensor_data, &tmp, model_params)); + if (!model) { + throw std::runtime_error("failed to create llama model"); + } + llama_context_ptr lctx(llama_init_from_model(model.get(), ctx_params)); + if (!lctx) { + throw std::runtime_error("failed to create llama context"); + } + return std::make_pair(std::move(model), std::move(lctx)); +} + +static bool compare_logits(const std::vector & a, const std::vector & b) { + if (a.size() != b.size()) { + return false; + } + const float threshold = 2e-5f; // Relaxed threshold for state save/load numerical precision + for (size_t i = 0; i < a.size(); i++) { + if (std::abs(a[i] - b[i]) > threshold) { + return false; + } + } + return true; +} + +static std::vector get_logits_from_context(llama_context * lctx) { + const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(lctx))); + std::vector logits; + logits.reserve(n_vocab); + const float * logits_ith = llama_get_logits_ith(lctx, -1); + for (int j = 0; j < n_vocab; j++) { + logits.push_back(logits_ith[j]); + } + return logits; +} + +static bool test_save_and_load_state(const gguf_context_ptr & gguf_ctx, int seed) { + const int key_idx = gguf_find_key(gguf_ctx.get(), "general.architecture"); + const char * arch_name = (key_idx == -1) ? "unknown" : gguf_get_val_str(gguf_ctx.get(), key_idx); + + const int vocab_key_idx = gguf_find_key(gguf_ctx.get(), "llama.vocab_size"); + const uint32_t n_vocab = (vocab_key_idx == -1) ? 128 : gguf_get_val_u32(gguf_ctx.get(), vocab_key_idx); + + const std::vector session_tokens = get_tokens(16, n_vocab, seed); + const char * session_file = "test_session.tmp"; + int n_ctx = 0; + bool ok = true; + try { + auto model_and_ctx = get_model_and_ctx(gguf_ctx.get(), seed); + llama_model * model = model_and_ctx.first.get(); + + std::vector logits1; + std::vector logits2; + // Decode a few tokens and save the session state. + { + llama_context * ctx = model_and_ctx.second.get(); + llama_batch batch = llama_batch_init(session_tokens.size(), 0, 1); + for (size_t i = 0; i < session_tokens.size(); ++i) { + common_batch_add(batch, session_tokens[i], i, {0}, i == (session_tokens.size() -1)); + } + + if (llama_decode(ctx, batch) != 0) { + throw std::runtime_error("llama_decode failed"); + } + llama_batch_free(batch); + + logits1 = get_logits_from_context(ctx); + n_ctx = llama_n_ctx(ctx); + + llama_state_save_file(ctx, session_file, session_tokens.data(), session_tokens.size()); + } + // Create a new llama_context and load and restore the session state + { + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = n_ctx; + ctx_params.n_threads = 4; + ctx_params.n_threads_batch = 4; + llama_context * ctx = llama_init_from_model(model, ctx_params); + + std::vector loaded_tokens(session_tokens.size()); + size_t n_loaded_tokens = 0; + if (llama_state_load_file(ctx, session_file, loaded_tokens.data(), loaded_tokens.size(), &n_loaded_tokens) != 1) { + throw std::runtime_error("llama_state_load_file failed"); + } + + if (n_loaded_tokens != session_tokens.size()) { + throw std::runtime_error("loaded incorrect number of tokens"); + } + + loaded_tokens.resize(n_loaded_tokens); + if (loaded_tokens != session_tokens) { + throw std::runtime_error("loaded session tokens do not match"); + } + + llama_memory_t mem = llama_get_memory(ctx); + fprintf(stderr, "Before replay: KV cache seq 0 max pos = %d\n", llama_memory_seq_pos_max(mem, 0)); + + if (!common_replay_last_token(ctx, loaded_tokens.back(), n_loaded_tokens)) { + throw std::runtime_error("failed to replay last token"); + } + + fprintf(stderr, "After replay: KV cache seq 0 max pos = %d\n", llama_memory_seq_pos_max(mem, 0)); + + logits2 = get_logits_from_context(ctx); + + // Verify we can continue decoding after load + llama_token next_token = get_tokens(1, n_vocab, seed + 100)[0]; + llama_batch batch = llama_batch_init(1, 0, 1); + common_batch_add(batch, next_token, n_loaded_tokens + 1, {0}, true); + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + throw std::runtime_error("failed to decode next token after load"); + } + llama_batch_free(batch); + + llama_free(ctx); + } + // Verify the logits from the original and the restore session state. + if (!compare_logits(logits1, logits2)) { + ok = false; + } + } catch (const std::exception & e) { + fprintf(stderr, "Exception during test for %s: %s\n", arch_name, e.what()); + ok = false; + } + std::remove(session_file); + fprintf(stderr, "Test save_load_state for arch '%s': %s\n", arch_name, ok ? "PASSED" : "FAILED"); + return ok; +} + +static gguf_context_ptr transformer_model() { + const llm_arch arch = LLM_ARCH_LLAMA; + + gguf_context_ptr ret{gguf_init_empty()}; + llama_model_saver ms{arch, ret.get()}; + + const uint32_t n_ctx = 128; + const uint32_t n_vocab = 50; + const uint32_t n_embd = 256; + const uint32_t n_head = 2; + const uint32_t n_ff = 384; + const uint32_t n_layer = 2; + const uint32_t n_embd_head = n_embd / n_head; + + ms.add_kv(LLM_KV_GENERAL_ARCHITECTURE, llm_arch_name(arch)); + ms.add_kv(LLM_KV_VOCAB_SIZE, n_vocab); + ms.add_kv(LLM_KV_CONTEXT_LENGTH, n_ctx); + ms.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd); + ms.add_kv(LLM_KV_FEED_FORWARD_LENGTH, n_ff); + ms.add_kv(LLM_KV_BLOCK_COUNT, n_layer); + ms.add_kv(LLM_KV_ATTENTION_HEAD_COUNT, n_head); + ms.add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, n_head); + ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT, n_embd_head); + ms.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f); + ms.add_kv(LLM_KV_TOKENIZER_MODEL, "no_vocab"); + + return ret; +} + +static gguf_context_ptr recurrent_model() { + const llm_arch arch = LLM_ARCH_MAMBA; + + gguf_context_ptr ret{gguf_init_empty()}; + llama_model_saver ms{arch, ret.get()}; + + const uint32_t n_ctx = 128; + const uint32_t n_vocab = 128; + const uint32_t n_embd = 256; + const uint32_t n_layer = 2; + + ms.add_kv(LLM_KV_GENERAL_ARCHITECTURE, llm_arch_name(arch)); + ms.add_kv(LLM_KV_VOCAB_SIZE, n_vocab); + ms.add_kv(LLM_KV_CONTEXT_LENGTH, n_ctx); + ms.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd); + ms.add_kv(LLM_KV_BLOCK_COUNT, n_layer); + ms.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f); + ms.add_kv(LLM_KV_SSM_CONV_KERNEL, uint32_t(4)); + ms.add_kv(LLM_KV_SSM_INNER_SIZE, 2 * n_embd); + ms.add_kv(LLM_KV_SSM_STATE_SIZE, uint32_t(16)); + ms.add_kv(LLM_KV_SSM_TIME_STEP_RANK, n_embd / 16); + ms.add_kv(LLM_KV_TOKENIZER_MODEL, "no_vocab"); + + return ret; +} + +static gguf_context_ptr hybrid_model() { + const llm_arch arch = LLM_ARCH_JAMBA; + + gguf_context_ptr ret{gguf_init_empty()}; + llama_model_saver ms{arch, ret.get()}; + + const uint32_t n_ctx = 128; + const uint32_t n_vocab = 128; + const uint32_t n_embd = 256; + const uint32_t n_head = 2; + const uint32_t n_ff = 384; + const uint32_t n_layer = 4; + const uint32_t n_embd_head = n_embd / n_head; + + ms.add_kv(LLM_KV_GENERAL_ARCHITECTURE, llm_arch_name(arch)); + ms.add_kv(LLM_KV_VOCAB_SIZE, n_vocab); + ms.add_kv(LLM_KV_CONTEXT_LENGTH, n_ctx); + ms.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd); + ms.add_kv(LLM_KV_FEED_FORWARD_LENGTH, n_ff); + ms.add_kv(LLM_KV_BLOCK_COUNT, n_layer); + ms.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f); + ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT, n_embd_head); + ms.add_kv(LLM_KV_TOKENIZER_MODEL, "no_vocab"); + + std::vector n_head_per_layer; + n_head_per_layer.reserve(n_layer); + for (uint32_t il = 0; il < n_layer; il++) { + n_head_per_layer.push_back(il == 1 ? 0 : n_head); + } + ms.add_kv(LLM_KV_ATTENTION_HEAD_COUNT, n_head_per_layer); + ms.add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, n_head_per_layer); + + ms.add_kv(LLM_KV_SSM_CONV_KERNEL, uint32_t(4)); + ms.add_kv(LLM_KV_SSM_INNER_SIZE, 2 * n_embd); + ms.add_kv(LLM_KV_SSM_STATE_SIZE, uint32_t(16)); + ms.add_kv(LLM_KV_SSM_TIME_STEP_RANK, n_embd / 16); + + return ret; +} + +static int test_save_load_models(const size_t seed) { + std::vector models_to_test; + // Add more models to test here. + models_to_test.push_back(transformer_model()); + models_to_test.push_back(recurrent_model()); + models_to_test.push_back(hybrid_model()); + + bool all_ok = true; + for (const gguf_context_ptr & gguf_ctx : models_to_test) { + all_ok = all_ok && test_save_and_load_state(gguf_ctx, seed); + } + return all_ok ? 0 : 1; +} + +int main(int argc, char ** argv) { + common_init(); + std::random_device rd; + + size_t seed = rd(); + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "-s") == 0 || strcmp(argv[i], "--seed") == 0) { + if (i + 1 < argc) { + seed = std::stoull(argv[++i]); + } else { + return 1; + } + } + } + + try { + return test_save_load_models(seed); + } catch (const std::exception & err) { + fprintf(stderr, "encountered runtime error: %s\n", err.what()); + return -1; + } +}