This commit is contained in:
Daniel Bevenius 2026-04-01 06:58:53 +02:00 committed by GitHub
commit 45e3fe1154
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 328 additions and 0 deletions

View File

@ -209,6 +209,7 @@ llama_build_and_test(
peg-parser/tests.h
)
llama_build_and_test(test-regex-partial.cpp)
llama_build_and_test(test-save-load-state.cpp)
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf")

View File

@ -0,0 +1,327 @@
#include "common.h"
#include "log.h"
#include "ggml-backend.h"
#include "ggml.h"
#include "gguf.h"
#include "ggml-cpp.h"
#include "llama.h"
#include "llama-cpp.h"
#include "../src/llama-arch.h"
#include "../src/llama-model-saver.h"
#include <cinttypes>
#include <cstdio>
#include <cstring>
#include <cstdint>
#include <random>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// Taken from test-llama-arch.cpp
static void set_tensor_data(struct ggml_tensor * tensor, void * userdata) {
std::hash<std::string> hasher;
std::mt19937 gen(hasher(tensor->name) + *(const size_t *) userdata);
std::normal_distribution<float> dis(0.0f, 1.0e-2f);
const int64_t ne = ggml_nelements(tensor);
if (tensor->type == GGML_TYPE_F32) {
std::vector<float> tmp(ne);
for (int64_t i = 0; i < ne; i++) {
tmp[i] = dis(gen);
}
ggml_backend_tensor_set(tensor, tmp.data(), 0, ggml_nbytes(tensor));
} else if (tensor->type == GGML_TYPE_F16) {
std::vector<ggml_fp16_t> tmp(ne);
for (int64_t i = 0; i < ne; i++) {
tmp[i] = ggml_fp32_to_fp16(dis(gen));
}
ggml_backend_tensor_set(tensor, tmp.data(), 0, ggml_nbytes(tensor));
} else {
GGML_ABORT("fatal error");
}
}
// Taken from test-llama-arch.cpp
static std::vector<llama_token> get_tokens(const uint32_t n_tokens, const uint32_t n_vocab, const size_t seed){
std::mt19937 gen(seed);
std::uniform_int_distribution<> dis(0, n_vocab - 1);
std::vector<llama_token> ret;
ret.reserve(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
ret.push_back(dis(gen));
}
return ret;
}
// Taken from test-llama-arch.cpp
static std::pair<llama_model_ptr, llama_context_ptr> get_model_and_ctx(
struct gguf_context * gguf_ctx, const size_t seed) {
llama_model_params model_params = llama_model_default_params();
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 0; // will be set from model
ctx_params.n_threads = 4;
ctx_params.n_threads_batch = 4;
size_t tmp = seed;
llama_model_ptr model(llama_model_init_from_user(gguf_ctx, set_tensor_data, &tmp, model_params));
if (!model) {
throw std::runtime_error("failed to create llama model");
}
llama_context_ptr lctx(llama_init_from_model(model.get(), ctx_params));
if (!lctx) {
throw std::runtime_error("failed to create llama context");
}
return std::make_pair(std::move(model), std::move(lctx));
}
static bool compare_logits(const std::vector<float> & a, const std::vector<float> & b) {
if (a.size() != b.size()) {
return false;
}
const float threshold = 2e-5f; // Relaxed threshold for state save/load numerical precision
for (size_t i = 0; i < a.size(); i++) {
if (std::abs(a[i] - b[i]) > threshold) {
return false;
}
}
return true;
}
static std::vector<float> get_logits_from_context(llama_context * lctx) {
const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(lctx)));
std::vector<float> logits;
logits.reserve(n_vocab);
const float * logits_ith = llama_get_logits_ith(lctx, -1);
for (int j = 0; j < n_vocab; j++) {
logits.push_back(logits_ith[j]);
}
return logits;
}
static bool test_save_and_load_state(const gguf_context_ptr & gguf_ctx, int seed) {
const int key_idx = gguf_find_key(gguf_ctx.get(), "general.architecture");
const char * arch_name = (key_idx == -1) ? "unknown" : gguf_get_val_str(gguf_ctx.get(), key_idx);
const int vocab_key_idx = gguf_find_key(gguf_ctx.get(), "llama.vocab_size");
const uint32_t n_vocab = (vocab_key_idx == -1) ? 128 : gguf_get_val_u32(gguf_ctx.get(), vocab_key_idx);
const std::vector<llama_token> session_tokens = get_tokens(16, n_vocab, seed);
const char * session_file = "test_session.tmp";
int n_ctx = 0;
bool ok = true;
try {
auto model_and_ctx = get_model_and_ctx(gguf_ctx.get(), seed);
llama_model * model = model_and_ctx.first.get();
std::vector<float> logits1;
std::vector<float> logits2;
// Decode a few tokens and save the session state.
{
llama_context * ctx = model_and_ctx.second.get();
llama_batch batch = llama_batch_init(session_tokens.size(), 0, 1);
for (size_t i = 0; i < session_tokens.size(); ++i) {
common_batch_add(batch, session_tokens[i], i, {0}, i == (session_tokens.size() -1));
}
if (llama_decode(ctx, batch) != 0) {
throw std::runtime_error("llama_decode failed");
}
llama_batch_free(batch);
logits1 = get_logits_from_context(ctx);
n_ctx = llama_n_ctx(ctx);
llama_state_save_file(ctx, session_file, session_tokens.data(), session_tokens.size());
}
// Create a new llama_context and load and restore the session state
{
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx;
ctx_params.n_threads = 4;
ctx_params.n_threads_batch = 4;
llama_context * ctx = llama_init_from_model(model, ctx_params);
std::vector<llama_token> loaded_tokens(session_tokens.size());
size_t n_loaded_tokens = 0;
if (llama_state_load_file(ctx, session_file, loaded_tokens.data(), loaded_tokens.size(), &n_loaded_tokens) != 1) {
throw std::runtime_error("llama_state_load_file failed");
}
if (n_loaded_tokens != session_tokens.size()) {
throw std::runtime_error("loaded incorrect number of tokens");
}
loaded_tokens.resize(n_loaded_tokens);
if (loaded_tokens != session_tokens) {
throw std::runtime_error("loaded session tokens do not match");
}
llama_memory_t mem = llama_get_memory(ctx);
fprintf(stderr, "Before replay: KV cache seq 0 max pos = %d\n", llama_memory_seq_pos_max(mem, 0));
if (!common_replay_last_token(ctx, loaded_tokens.back(), n_loaded_tokens)) {
throw std::runtime_error("failed to replay last token");
}
fprintf(stderr, "After replay: KV cache seq 0 max pos = %d\n", llama_memory_seq_pos_max(mem, 0));
logits2 = get_logits_from_context(ctx);
// Verify we can continue decoding after load
llama_token next_token = get_tokens(1, n_vocab, seed + 100)[0];
llama_batch batch = llama_batch_init(1, 0, 1);
common_batch_add(batch, next_token, n_loaded_tokens + 1, {0}, true);
if (llama_decode(ctx, batch) != 0) {
llama_batch_free(batch);
throw std::runtime_error("failed to decode next token after load");
}
llama_batch_free(batch);
llama_free(ctx);
}
// Verify the logits from the original and the restore session state.
if (!compare_logits(logits1, logits2)) {
ok = false;
}
} catch (const std::exception & e) {
fprintf(stderr, "Exception during test for %s: %s\n", arch_name, e.what());
ok = false;
}
std::remove(session_file);
fprintf(stderr, "Test save_load_state for arch '%s': %s\n", arch_name, ok ? "PASSED" : "FAILED");
return ok;
}
static gguf_context_ptr transformer_model() {
const llm_arch arch = LLM_ARCH_LLAMA;
gguf_context_ptr ret{gguf_init_empty()};
llama_model_saver ms{arch, ret.get()};
const uint32_t n_ctx = 128;
const uint32_t n_vocab = 50;
const uint32_t n_embd = 256;
const uint32_t n_head = 2;
const uint32_t n_ff = 384;
const uint32_t n_layer = 2;
const uint32_t n_embd_head = n_embd / n_head;
ms.add_kv(LLM_KV_GENERAL_ARCHITECTURE, llm_arch_name(arch));
ms.add_kv(LLM_KV_VOCAB_SIZE, n_vocab);
ms.add_kv(LLM_KV_CONTEXT_LENGTH, n_ctx);
ms.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd);
ms.add_kv(LLM_KV_FEED_FORWARD_LENGTH, n_ff);
ms.add_kv(LLM_KV_BLOCK_COUNT, n_layer);
ms.add_kv(LLM_KV_ATTENTION_HEAD_COUNT, n_head);
ms.add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, n_head);
ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT, n_embd_head);
ms.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
ms.add_kv(LLM_KV_TOKENIZER_MODEL, "no_vocab");
return ret;
}
static gguf_context_ptr recurrent_model() {
const llm_arch arch = LLM_ARCH_MAMBA;
gguf_context_ptr ret{gguf_init_empty()};
llama_model_saver ms{arch, ret.get()};
const uint32_t n_ctx = 128;
const uint32_t n_vocab = 128;
const uint32_t n_embd = 256;
const uint32_t n_layer = 2;
ms.add_kv(LLM_KV_GENERAL_ARCHITECTURE, llm_arch_name(arch));
ms.add_kv(LLM_KV_VOCAB_SIZE, n_vocab);
ms.add_kv(LLM_KV_CONTEXT_LENGTH, n_ctx);
ms.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd);
ms.add_kv(LLM_KV_BLOCK_COUNT, n_layer);
ms.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
ms.add_kv(LLM_KV_SSM_CONV_KERNEL, uint32_t(4));
ms.add_kv(LLM_KV_SSM_INNER_SIZE, 2 * n_embd);
ms.add_kv(LLM_KV_SSM_STATE_SIZE, uint32_t(16));
ms.add_kv(LLM_KV_SSM_TIME_STEP_RANK, n_embd / 16);
ms.add_kv(LLM_KV_TOKENIZER_MODEL, "no_vocab");
return ret;
}
static gguf_context_ptr hybrid_model() {
const llm_arch arch = LLM_ARCH_JAMBA;
gguf_context_ptr ret{gguf_init_empty()};
llama_model_saver ms{arch, ret.get()};
const uint32_t n_ctx = 128;
const uint32_t n_vocab = 128;
const uint32_t n_embd = 256;
const uint32_t n_head = 2;
const uint32_t n_ff = 384;
const uint32_t n_layer = 4;
const uint32_t n_embd_head = n_embd / n_head;
ms.add_kv(LLM_KV_GENERAL_ARCHITECTURE, llm_arch_name(arch));
ms.add_kv(LLM_KV_VOCAB_SIZE, n_vocab);
ms.add_kv(LLM_KV_CONTEXT_LENGTH, n_ctx);
ms.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd);
ms.add_kv(LLM_KV_FEED_FORWARD_LENGTH, n_ff);
ms.add_kv(LLM_KV_BLOCK_COUNT, n_layer);
ms.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT, n_embd_head);
ms.add_kv(LLM_KV_TOKENIZER_MODEL, "no_vocab");
std::vector<uint32_t> n_head_per_layer;
n_head_per_layer.reserve(n_layer);
for (uint32_t il = 0; il < n_layer; il++) {
n_head_per_layer.push_back(il == 1 ? 0 : n_head);
}
ms.add_kv(LLM_KV_ATTENTION_HEAD_COUNT, n_head_per_layer);
ms.add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, n_head_per_layer);
ms.add_kv(LLM_KV_SSM_CONV_KERNEL, uint32_t(4));
ms.add_kv(LLM_KV_SSM_INNER_SIZE, 2 * n_embd);
ms.add_kv(LLM_KV_SSM_STATE_SIZE, uint32_t(16));
ms.add_kv(LLM_KV_SSM_TIME_STEP_RANK, n_embd / 16);
return ret;
}
static int test_save_load_models(const size_t seed) {
std::vector<gguf_context_ptr> models_to_test;
// Add more models to test here.
models_to_test.push_back(transformer_model());
models_to_test.push_back(recurrent_model());
models_to_test.push_back(hybrid_model());
bool all_ok = true;
for (const gguf_context_ptr & gguf_ctx : models_to_test) {
all_ok = all_ok && test_save_and_load_state(gguf_ctx, seed);
}
return all_ok ? 0 : 1;
}
int main(int argc, char ** argv) {
common_init();
std::random_device rd;
size_t seed = rd();
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "-s") == 0 || strcmp(argv[i], "--seed") == 0) {
if (i + 1 < argc) {
seed = std::stoull(argv[++i]);
} else {
return 1;
}
}
}
try {
return test_save_load_models(seed);
} catch (const std::exception & err) {
fprintf(stderr, "encountered runtime error: %s\n", err.what());
return -1;
}
}