#include "ggml.h"
#include "llama.h"
#include "get-model.h"
#include "common.h"

#ifdef NDEBUG
#undef NDEBUG
#endif

#include <array>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

// Helper that owns the model/context, tracks the current position of each sequence,
// and remembers which batch index produced the last output token for each sequence.
struct test_model_context {
    llama_model       * model = nullptr;
    llama_context     * ctx   = nullptr;
    const llama_vocab * vocab = nullptr;
    int n_vocab = 0;

    std::unordered_map<llama_seq_id, llama_pos> seq_positions;
    std::unordered_map<llama_seq_id, int32_t>   last_batch_info;

    bool load_model(const char * model_path) {
        if (model != nullptr) {
            return true;
        }
        llama_backend_init();
        model = llama_model_load_from_file(model_path, llama_model_default_params());
        if (model == nullptr) {
            fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
            cleanup();
            return false;
        }
        vocab   = llama_model_get_vocab(model);
        n_vocab = llama_vocab_n_tokens(vocab);
        fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
        return true;
    }

    bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
        if (model == nullptr) {
            load_model(model_path);
        }
        if (ctx != nullptr) {
            return true;
        }
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx      = 512;
        cparams.n_batch    = 512;
        cparams.samplers   = configs.data();
        cparams.n_samplers = configs.size();
        // If n_seq_max is not specified, calculate it from configs.
        if (n_seq_max < 0) {
            int32_t max_seq_id = 0;
            for (const auto & config : configs) {
                if (config.seq_id > max_seq_id) {
                    max_seq_id = config.seq_id;
                }
            }
            cparams.n_seq_max = max_seq_id + 1;
        } else {
            cparams.n_seq_max = n_seq_max;
        }
        ctx = llama_init_from_model(model, cparams);
        if (ctx == nullptr) {
            fprintf(stderr, "Warning: failed to create context, skipping test\n");
            cleanup();
            return false;
        }
        llama_set_warmup(ctx, false);
        return true;
    }

    bool decode(const std::map<llama_seq_id, std::string> & prompts) {
        if (ctx == nullptr || vocab == nullptr) {
            fprintf(stderr, "Error: context not initialized, call setup() first\n");
            return false;
        }
        last_batch_info.clear();

        llama_batch batch = llama_batch_init(512, 0, prompts.size());
        int n_tokens_per_prompt = 0;
        for (const auto & [seq_id, prompt] : prompts) {
            std::vector<llama_token> tokens;
            tokens.push_back(llama_vocab_bos(vocab));

            std::vector<llama_token> prompt_tokens(32);
            int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                    prompt_tokens.data(), prompt_tokens.size(), false, false);
            // TODO: refactor this function to just handle a single prompt at a time
            // to avoid this check and complexity.
            if (n_tokens_per_prompt == 0) {
                n_tokens_per_prompt = n_tokens;
            } else {
                if (n_tokens != n_tokens_per_prompt) {
                    fprintf(stderr, "Error: prompts must have the same number of tokens\n");
                    llama_batch_free(batch);
                    return false;
                }
                n_tokens_per_prompt = n_tokens;
            }
            if (n_tokens < 0) {
                fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                llama_batch_free(batch);
                return false;
            }
            for (int i = 0; i < n_tokens; i++) {
                tokens.push_back(prompt_tokens[i]);
            }
            for (size_t i = 0; i < tokens.size(); i++) {
                common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
            }
            seq_positions[seq_id] = tokens.size();
        }

        printf("Batch contents:\n");
        printf("n_tokens: %d\n", batch.n_tokens);
        for (int i = 0; i < batch.n_tokens; i++) {
            printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[",
                    i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);
            for (int j = 0; j < batch.n_seq_id[i]; j++) {
                printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i] - 1 ? ", " : "");
            }
            printf("], logits=%d\n", batch.logits[i]);
        }

        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed\n");
            llama_batch_free(batch);
            return false;
        }

        // Build mapping from seq id to batch token idx.
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id seq_id = batch.seq_id[i][0];
                last_batch_info[seq_id] = i;
            }
        }

        llama_batch_free(batch);
        return true;
    }

    int32_t idx_for_seq(llama_seq_id seq_id) {
        auto it = last_batch_info.find(seq_id);
        if (it == last_batch_info.end()) {
            fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id);
            return -1;
        }
        return it->second;
    }

    bool decode_token(llama_token token, llama_seq_id seq_id = 0) {
        if (ctx == nullptr) {
            fprintf(stderr, "Error: context not initialized, call setup() first\n");
            return false;
        }
        llama_batch batch = llama_batch_init(1, 0, 1);
        int32_t pos = seq_positions[seq_id];
        common_batch_add(batch, token, pos, { seq_id }, true);
        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id);
            llama_batch_free(batch);
            return false;
        }
        last_batch_info.clear();
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id cur_seq = batch.seq_id[i][0];
                last_batch_info[cur_seq] = i;
            }
        }
        seq_positions[seq_id]++;
        llama_batch_free(batch);
        return true;
    }

    bool decode_tokens(const std::map<llama_seq_id, llama_token> & seq_tokens) {
        if (ctx == nullptr) {
            fprintf(stderr, "Error: context not initialized, call setup() first\n");
            return false;
        }
        llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size());
        for (const auto & [seq_id, token] : seq_tokens) {
            int32_t pos = seq_positions[seq_id];
            common_batch_add(batch, token, pos, { seq_id }, true);
        }
        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed for batch tokens\n");
            llama_batch_free(batch);
            return false;
        }
        for (const auto & [seq_id, _] : seq_tokens) {
            seq_positions[seq_id]++;
        }
        last_batch_info.clear();
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id cur_seq = batch.seq_id[i][0];
                last_batch_info[cur_seq] = i;
            }
        }
        llama_batch_free(batch);
        return true;
    }

    std::string token_to_piece(llama_token token, bool special) {
        std::string piece;
        piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\0'
        const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
        if (n_chars < 0) {
            piece.resize(-n_chars);
            int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
            GGML_ASSERT(check == -n_chars);
        } else {
            piece.resize(n_chars);
        }
        return piece;
    }

    void reset() {
        if (ctx) {
            llama_free(ctx);
            ctx = nullptr;
        }
        seq_positions.clear();
        last_batch_info.clear();
    }

    void cleanup() {
        if (ctx) {
            llama_free(ctx);
        }
        if (model) {
            llama_model_free(model);
        }
        llama_backend_free();
        ctx   = nullptr;
        model = nullptr;
        vocab = nullptr;
    }

    ~test_model_context() {
        cleanup();
    }
};

static void test_backend_greedy_sampling(const char * model_path) {
    test_model_context test_ctx;

    const int seq_id = 0;

    struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_greedy());

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }
    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, loop_idx);
        printf("Generation step %d: token id:%d, string: %s\n",
                i, token, test_ctx.token_to_piece(token, false).c_str());
        if (!test_ctx.decode_token(token, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }
}

static void test_backend_top_k_sampling(const char * model_path) {
    test_model_context test_ctx;

    const int     seq_id = 0;
    const int32_t k      = 8;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_k(k));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }
    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float *  logits   = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
    for (size_t i = 0; i < n_logits; ++i) {
        printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
    }

    llama_token * candidates   = llama_get_sampled_candidates_ith(test_ctx.ctx, batch_idx);
    uint32_t      n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx, batch_idx);
    for (size_t i = 0; i < n_candidates; ++i) {
        printf("top_k candidate[%zu] = %d : %s\n",
                i, candidates[i], test_ctx.token_to_piece(candidates[i], false).c_str());
    }

    // Sample using a CPU sampler to verify that it is possible to do hybrid
    // sampling: first top_k on the backend and then dist on the CPU.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    GGML_ASSERT(chain->iface->backend_apply != nullptr);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    printf("backend top-k hybrid sampling test PASSED\n");
    llama_sampler_free(chain);
}

static void test_backend_temp_sampling(const char * model_path) {
    test_model_context test_ctx;

    {
        const float temp_0 = 0.8f;
        struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
        struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
        llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));

        const float temp_1 = 0.1f;
        struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
        struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
        llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { 0, backend_sampler_chain_0 },
            { 1, backend_sampler_chain_1 }
        };

        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
            return;
        }
        if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(0);
            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);

            // Sample from sequence 0 using a CPU sampler.
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
            llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
            llama_sampler_free(chain);
        }

        // Verify sequence 1
        {
            int32_t batch_idx = test_ctx.idx_for_seq(1);

            // Sample from sequence 1 using a CPU sampler.
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
            llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
            llama_sampler_free(chain);
        }
    }

    // Lambda for testing non-positive temperature values.
    auto test_argmax_temp = [&](float temp) {
        printf("\nTesting temperature = %.1f\n", temp);
        test_ctx.reset();

        int seq_id = 0;
        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp(temp));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain },
        };

        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
            return;
        }
        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t  batch_idx = test_ctx.idx_for_seq(seq_id);
        uint32_t n_logits  = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
        GGML_ASSERT(n_logits == 1);
    };

    test_argmax_temp(0.0f);
    test_argmax_temp(-1.0f);

    printf("backend temp sampling test PASSED\n");
}

static void test_backend_temp_ext_sampling(const char * model_path) {
    test_model_context test_ctx;

    {
        int seq_id = 0;
        const float temp     = 0.8f;
        const float delta    = 0.5f;
        const float exponent = 1.5f;

        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain },
        };

        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
            return;
        }
        if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);
        }
    }

    test_ctx.reset();

    // Lambda for testing non-positive temp/delta/exponent values.
    auto test_argmax_temp = [&](float temp, float delta, float exponent) {
        printf("\nTesting temperature = %.1f, delta = %.1f, exponent = %.1f\n", temp, delta, exponent);
        test_ctx.reset();

        int seq_id = 0;
        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain },
        };

        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
            return;
        }
        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t  batch_idx = test_ctx.idx_for_seq(seq_id);
        uint32_t n_logits  = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
        if (temp <= 0.0f && delta >= 0.0f) {
            GGML_ASSERT(n_logits == 1);
        } else {
            GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
        }
    };

    test_argmax_temp( 0.0f, 0.3f, 1.0f); // Greedy (temp=0)
    test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
    test_argmax_temp( 0.8f, 0.0f, 2.0f); // Temperature scaling

    printf("backend temp_ext sampling test PASSED\n");
}

static void test_backend_min_p_sampling(const char * model_path) {
    test_model_context test_ctx;

    const int   seq_id = 0;
    const float p      = 0.1f;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_min_p(p, 0));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }
    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float *  logits   = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);

    // Collect the logits that are above the min-p threshold.
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
            //printf("min_p logit[%zu] = %.6f\n", i, logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);

    // Sample using a CPU sampler to verify the results are reasonable.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens.
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
        printf("min-p gen step %d: token id :%5d, string: %s\n",
                i, token, test_ctx.token_to_piece(token, false).c_str());
        if (!test_ctx.decode_token(token, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }

    printf("min-p sampling test PASSED\n");
    llama_sampler_free(chain);
}

static void test_backend_top_p_sampling(const char * model_path) {
    test_model_context test_ctx;

    const int   seq_id = 0;
    const float p      = 0.9f;
    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_p(p, 0));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }
    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        return;
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float *  logits   = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);

    // Collect the logits that survive the top-p filter.
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
    GGML_ASSERT(filtered_logits.size() > 0);

    // Sample using a CPU sampler to verify the results are reasonable.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens.
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
        printf("top-p gen step %d: token id :%5d, string: %s\n",
                i, token, test_ctx.token_to_piece(token, false).c_str());
        test_ctx.decode_token(token, 0);
    }

    printf("top-p sampling test PASSED\n");
    llama_sampler_free(chain);
}

static void test_backend_multi_sequence_sampling(const char * model_path) {
    test_model_context test_ctx;

    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_greedy());

    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
    struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_greedy());

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0 },
        { 1, sampler_chain_1 }
    };

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"},
        {1, "Some"}
    };
    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Verify sequence 1
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Generate tokens for each sequence.
    printf("\nMulti-sequence generation:\n");
    for (int step = 0; step < 4; step++) {
        std::map<llama_seq_id, llama_token> tokens;
        for (llama_seq_id seq_id : {0, 1}) {
            int32_t idx = test_ctx.idx_for_seq(seq_id);
            llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
            tokens[seq_id] = token;
        }
        // Decode all tokens in a single batch.
        if (!test_ctx.decode_tokens(tokens)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }

    printf("backend multi-sequence sampling test PASSED\n");
}

static void test_backend_dist_sampling(const char * model_path) {
    test_model_context test_ctx;

    const int     seq_id = 189;
    const int32_t seed   = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }
    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);

    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
}

static void test_backend_dist_sampling_and_cpu(const char * model_path) {
    test_model_context test_ctx;

    const int     seq_id = 0;
    const int32_t seed   = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }
    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using a CPU sampler.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
    llama_token cpu_token     = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
    printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
    GGML_ASSERT(backend_token == cpu_token);
}

static void test_backend_logit_bias_sampling(const char * model_path) {
    test_model_context test_ctx;

    // Call load_model to ensure the vocab is loaded and can be accessed.
    if (!test_ctx.load_model(model_path)) {
        return;
    }

    const int seq_id = 0;

    // Create the logit biases vector.
    std::vector<llama_logit_bias> logit_bias;

    // Get the token for the piece "World".
    const std::string piece = "World";
    std::vector<llama_token> tokens(16);
    llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
    llama_token bias_token = tokens[0];
    logit_bias.push_back({ bias_token, +100.0f });
    printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain,
            llama_sampler_init_logit_bias(llama_vocab_n_tokens(test_ctx.vocab), logit_bias.size(), logit_bias.data()));
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { seq_id, backend_sampler_chain },
    };

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }
    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
    printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
    GGML_ASSERT(backend_token == bias_token);
}

// This test verifies that it is possible to have two different backend samplers:
// one sequence uses the backend dist sampler, and the other uses a backend top-k
// sampler whose filtered logits are left for CPU-side sampling.
static void test_backend_mixed_sampling(const char * model_path) {
    test_model_context test_ctx;

    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));

    int k = 40;
    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
    struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_top_k(k));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0 },
        { 1, sampler_chain_1 }
    };

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"},
        {1, "Some"}
    };
    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0, which used the dist backend sampler.
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
    }

    // Verify sequence 1, which used the top-k backend sampler.
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);

        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
        GGML_ASSERT(logits != nullptr);

        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
        GGML_ASSERT(n_logits == (size_t) k);

        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
    }

    printf("backend mixed sampling test PASSED\n");
}

static void test_backend_set_sampler(const char * model_path) {
    test_model_context test_ctx;

    const int32_t seed   = 88;
    const int     seq_id = 0;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }
    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using the backend sampler configured above.
    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());

    // Now clear the backend sampler for this sequence.
    llama_set_sampler(test_ctx.ctx, seq_id, nullptr);
    printf("Cleared backend sampler for seq_id %d\n", seq_id);

    // Sample using a CPU sampler.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

    std::map<llama_seq_id, llama_token> tokens = {
        { seq_id, backend_token },
    };
    if (!test_ctx.decode_tokens(tokens)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Should not have any sampled token or probs after clearing the backend sampler.
    const int32_t idx = test_ctx.idx_for_seq(seq_id);
    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL);
    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx, idx) == nullptr);

    // Sample the token using the CPU sampler chain.
    llama_token token2 = llama_sampler_sample(chain, test_ctx.ctx, seq_id);
    const std::string token2_str = test_ctx.token_to_piece(token2, false);
    printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());

    std::map<llama_seq_id, llama_token> tokens2 = {
        { seq_id, token2 },
    };

    // Set a new backend sampler for the sequence.
    struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params);
    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_top_k(20));
    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_dist(seed));
    llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);

    if (!test_ctx.decode_tokens(tokens2)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
    const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
}

static void test_backend_cpu_mixed_batch(const char * model_path) {
    test_model_context test_ctx;

    // Sequence 0 uses backend sampling.
    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0 },
    };

    // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling.
    if (!test_ctx.setup(model_path, backend_sampler_configs, 2)) {
        return;
    }

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"}, // Will use backend sampling
        {1, "Some"}   // Will use CPU sampling
    };
    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0 (backend sampled)
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Verify sequence 1 (CPU sampled)
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);

        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
        GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);

        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
        llama_sampler_chain_add(chain, llama_sampler_init_greedy());

        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        llama_sampler_free(chain);
    }

    // Clear/remove the backend sampler, and sample again.
    {
        // Clear the backend sampler for seq 0 so that there are no backend
        // samplers.
        llama_set_sampler(test_ctx.ctx, 0, nullptr);

        // Create a CPU sampler and verify we can sample with it.
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
        llama_sampler_chain_add(chain, llama_sampler_init_greedy());

        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
        if (!test_ctx.decode_token(token, 1)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
        llama_sampler_free(chain);
    }

    // Set a backend sampler so that we can verify that it can be reset.
    {
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * sampler_chain = llama_sampler_chain_init(chain_params);
        llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(88));
        llama_set_sampler(test_ctx.ctx, 0, sampler_chain);

        if (!test_ctx.decode_token(3834, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    printf("backend-cpu mixed batch test PASSED\n");
}

static void test_backend_max_outputs(const char * model_path) {
    test_model_context test_ctx;

    const int     seq_id = 0;
    const int32_t seed   = 88;

    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }

    llama_batch batch = llama_batch_init(512, 0, 1);

    std::string prompt = "Hello";
    std::vector<llama_token> tokens;
    tokens.push_back(llama_vocab_bos(test_ctx.vocab));

    std::vector<llama_token> prompt_tokens(32);
    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
            prompt_tokens.data(), prompt_tokens.size(), false, false);
    for (int i = 0; i < n_tokens; i++) {
        tokens.push_back(prompt_tokens[i]);
    }
    for (size_t i = 0; i < tokens.size(); i++) {
        // Mark all tokens as output to trigger the error.
        common_batch_add(batch, tokens[i], i, { seq_id }, true);
    }

    printf(">>> test_max_outputs expected error start:\n");
    const int ret = llama_decode(test_ctx.ctx, batch);
    GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence");
    printf("<<< test_max_outputs expected error end.\n");

    llama_batch_free(batch);
}

struct backend_test_case {
    const char * name;
    void (*fn)(const char *);
    bool enabled_by_default;
};

static const backend_test_case BACKEND_TESTS[] = {
    { "greedy",         test_backend_greedy_sampling,         true },
    { "logit_bias",     test_backend_logit_bias_sampling,     true },
    { "temp",           test_backend_temp_sampling,           true },
    { "temp_ext",       test_backend_temp_ext_sampling,       true },
    { "top_k",          test_backend_top_k_sampling,          true },
    { "multi_sequence", test_backend_multi_sequence_sampling, true },
    { "dist",           test_backend_dist_sampling,           true },
    { "dist_and_cpu",   test_backend_dist_sampling_and_cpu,   true },
    { "set_sampler",    test_backend_set_sampler,             true },
    { "max_outputs",    test_backend_max_outputs,             true },
    { "mixed",          test_backend_mixed_sampling,          true },
    { "min_p",          test_backend_min_p_sampling,          true },
    { "cpu_mixed",      test_backend_cpu_mixed_batch,         true },
    { "top_p",          test_backend_top_p_sampling,          true },
};

struct backend_cli_args {
    const char * model = nullptr;
    const char * test  = nullptr;
};

static backend_cli_args parse_backend_cli(int argc, char ** argv) {
    backend_cli_args out;
    for (int i = 1; i < argc; ++i) {
        const char * arg = argv[i];
        if (std::strcmp(arg, "--test") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--test expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.test = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--test=", 7) == 0) {
            out.test = arg + 7;
            continue;
        }
        if (std::strcmp(arg, "--model") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--model expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.model = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--model=", 8) == 0) {
            out.model = arg + 8;
            continue;
        }
        if (!out.model) {
            out.model = arg;
            continue;
        }
        fprintf(stderr, "Unexpected argument: %s\n", arg);
        exit(EXIT_FAILURE);
    }
    return out;
}

static std::vector<const backend_test_case *> collect_tests_to_run(const char * requested) {
    std::vector<const backend_test_case *> selected;
    if (requested != nullptr) {
        for (const auto & test : BACKEND_TESTS) {
            if (std::strcmp(test.name, requested) == 0) {
                selected.push_back(&test);
                break;
            }
        }
        if (selected.empty()) {
            fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested);
            for (const auto & test : BACKEND_TESTS) {
                fprintf(stderr, " %s\n", test.name);
            }
            exit(EXIT_FAILURE);
        }
    } else {
        for (const auto & test : BACKEND_TESTS) {
            if (test.enabled_by_default) {
                selected.push_back(&test);
            }
        }
    }
    if (selected.empty()) {
        fprintf(stderr, "No backend sampling tests selected. Use --test=<name> to pick one.\n");
    }
    return selected;
}

static void run_tests(const std::vector<const backend_test_case *> & tests, const char * model_path) {
    for (const auto * test : tests) {
        fprintf(stderr, "\n=== %s ===\n", test->name);
        test->fn(model_path);
    }
}

int main(int argc, char * argv[]) {
    const backend_cli_args args = parse_backend_cli(argc, argv);

    std::array<char *, 2> model_argv { argv[0], const_cast<char *>(args.model) };
    const int model_argc = args.model ? 2 : 1;
    char * model_path = get_model_or_exit(model_argc, model_argv.data());

    auto * file = fopen(model_path, "r");
    if (file == nullptr) {
        fprintf(stderr, "no model at '%s' found\n", model_path);
        return EXIT_FAILURE;
    }
    fprintf(stderr, "using '%s'\n", model_path);
    fclose(file);

    ggml_time_init();

    const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
    if (!tests.empty()) {
        run_tests(tests, model_path);
    }

    return 0;
}
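
// Example invocations of this test driver, based on the CLI parsing above. The
// binary and model file names below are placeholders (they depend on the local
// build and on which model is available), so adjust them as needed:
//
//   ./test-backend-sampling --model=models/ggml-model.gguf               # run all default-enabled tests
//   ./test-backend-sampling --model models/ggml-model.gguf --test=top_k  # run a single test by name
//
// When no model argument is given, get_model_or_exit() decides how to locate a
// model (or exits), so the tests can also be driven by the surrounding test setup.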