diff --git a/common/arg.cpp b/common/arg.cpp index 62d31393c4..edf7a92750 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1397,7 +1397,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.warmup = value; } - ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); + ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--spm-infill"}, string_format( @@ -1706,7 +1706,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING")); add_opt(common_arg( {"--attention"}, "{causal,non-causal}", "attention type for embeddings, use model default if unspecified", @@ -2579,7 +2579,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.embd_normalize = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", @@ -2657,7 +2657,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.embedding = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_EMBEDDINGS")); add_opt(common_arg( {"--rerank", "--reranking"}, string_format("enable reranking endpoint on server (default: %s)", "disabled"), @@ -3344,6 +3344,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_examples({ LLAMA_EXAMPLE_FINETUNE })); + add_opt(common_arg( + {"--save-logits"}, + string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"), + [](common_params & params) { + params.save_logits = true; + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); + add_opt(common_arg( + {"--logits-output-dir"}, "PATH", + string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()), + [](common_params & params, const std::string & value) { + params.logits_output_dir = value; + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); + add_opt(common_arg( + {"--tensor-filter"}, "REGEX", + "filter tensor names for debug output (regex pattern, can be specified multiple times)", + [](common_params & params, const std::string & value) { + params.tensor_filter.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); // presets add_opt(common_arg( diff --git a/common/common.h b/common/common.h index f8bc686b6f..34b8968d96 100644 --- a/common/common.h +++ b/common/common.h @@ -80,6 +80,7 @@ int32_t cpu_get_num_math(); // enum llama_example { + LLAMA_EXAMPLE_DEBUG, LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_COMPLETION, @@ -370,6 +371,11 @@ struct common_params { std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT + // llama-debug specific options + std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT + bool save_logits = false; // whether to save logits to files // NOLINT + std::vector tensor_filter; // filter tensor names for debug output (regex) // NOLINT + std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 91797cf78a..a29dc707c3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -15,6 +15,7 @@ llama_add_compile_flags() if (EMSCRIPTEN) else() add_subdirectory(batched) + add_subdirectory(debug) add_subdirectory(embedding) add_subdirectory(eval-callback) @@ -34,7 +35,6 @@ else() add_subdirectory(gen-docs) add_subdirectory(training) add_subdirectory(diffusion) - add_subdirectory(model-conversion) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) # these examples use the backends directly and cannot be built with dynamic loading diff --git a/examples/model-conversion/CMakeLists.txt b/examples/debug/CMakeLists.txt similarity index 73% rename from examples/model-conversion/CMakeLists.txt rename to examples/debug/CMakeLists.txt index fc1746ce45..34593072be 100644 --- a/examples/model-conversion/CMakeLists.txt +++ b/examples/debug/CMakeLists.txt @@ -1,5 +1,5 @@ -set(TARGET llama-logits) -add_executable(${TARGET} logits.cpp) +set(TARGET llama-debug) +add_executable(${TARGET} debug.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/debug/README.md b/examples/debug/README.md new file mode 100644 index 0000000000..28e00c9342 --- /dev/null +++ b/examples/debug/README.md @@ -0,0 +1,54 @@ +# llama.cpp/examples/debug + +This is a utility intended to help debug a model by registering a callback that +logs GGML operations and tensor data. It can also store the generated logits or +embeddings as well as the prompt and token ids for comparision with the original +model. + +### Usage + +```shell +llama-debug \ + --hf-repo ggml-org/models \ + --hf-file phi-2/ggml-model-q4_0.gguf \ + --model phi-2-q4_0.gguf \ + --prompt hello \ + --save-logits \ + --verbose +``` +The tensor data is logged as debug and required the --verbose flag. The reason +for this is that while useful for a model with many layers there can be a lot of +output. You can filter the tensor names using the `--tensor-filter` option. + +A recommended approach is to first run without `--verbose` and see if the +generated logits/embeddings are close to the original model. If they are not, +then it might be required to inspect tensor by tensor and in that case it is +useful to enable the `--verbose` flag along with `--tensor-filter` to focus on +specific tensors. + +### Options +This example supports all standard `llama.cpp` options and also accepts the +following options: +```console +$ llama-debug --help +... + +----- example-specific params ----- + +--save-logits save final logits to files for verification (default: false) +--logits-output-dir PATH directory for saving logits output files (default: data) +--tensor-filter REGEX filter tensor names for debug output (regex pattern, can be specified multiple times) +``` + +### Output Files + +When `--save-logits` is enabled, the following files are created in the output +directory: + +* `llamacpp-[-embeddings].bin` - Binary output (logits or embeddings) +* `llamacpp-[-embeddings].txt` - Text output (logits or embeddings, one per line) +* `llamacpp-[-embeddings]-prompt.txt` - Prompt text and token IDs +* `llamacpp-[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison + +These files can be compared against the original model's output to verify the +converted model. diff --git a/examples/debug/debug.cpp b/examples/debug/debug.cpp new file mode 100644 index 0000000000..bca55981ce --- /dev/null +++ b/examples/debug/debug.cpp @@ -0,0 +1,399 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data); + +struct callback_data { + std::vector data; + std::vector tensor_filters; + + callback_data() = default; + + callback_data(common_params & params, const std::vector & filter_patterns) { + for (const auto & pattern : filter_patterns) { + try { + std::string anchored_pattern = "^" + pattern; + tensor_filters.emplace_back(anchored_pattern, std::regex::optimize); + } catch (const std::regex_error & e) { + throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); + } + } + params.cb_eval = ggml_debug; + params.cb_eval_user_data = this; + } +}; + +struct output_data { + float * data_ptr = nullptr; + int data_size = 0; + std::string type_suffix; + std::vector storage; + std::string prompt; + std::vector tokens; + + output_data(llama_context * ctx, const llama_model * model, const common_params & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); + + tokens = common_tokenize(ctx, params.prompt, add_bos); + prompt = params.prompt; + + if (params.embedding) { + const int n_embd = llama_model_n_embd(model); + const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE; + const int n_embd_count = pooling_enabled ? 1 : tokens.size(); + const int n_embeddings = n_embd * n_embd_count; + + float * embeddings; + if (pooling_enabled) { + embeddings = llama_get_embeddings_seq(ctx, 0); + storage.resize(n_embeddings); + common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize); + embeddings = storage.data(); + } else { + embeddings = llama_get_embeddings(ctx); + } + + data_ptr = embeddings; + data_size = n_embeddings; + type_suffix = "-embeddings"; + } else { + const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1); + const int n_logits = llama_vocab_n_tokens(vocab); + + data_ptr = const_cast(logits); + data_size = n_logits; + type_suffix = ""; + } + } +}; + +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + +static float ggml_get_float_value(const uint8_t * data, ggml_type type, + const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + switch (type) { + case GGML_TYPE_F16: + return ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]); + case GGML_TYPE_F32: + return *(const float *) &data[i]; + case GGML_TYPE_I64: + return (float) *(const int64_t *) &data[i]; + case GGML_TYPE_I32: + return (float) *(const int32_t *) &data[i]; + case GGML_TYPE_I16: + return (float) *(const int16_t *) &data[i]; + case GGML_TYPE_I8: + return (float) *(const int8_t *) &data[i]; + case GGML_TYPE_BF16: + return ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]); + default: + GGML_ABORT("fatal error"); + } +} + +static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + GGML_ASSERT(n > 0); + float sum = 0; + float sum_sq = 0.0; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + sum += v; + sum_sq += v * v; + } + } + } + } + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + LOG_DBG(" [\n"); + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + LOG_DBG(" ..., \n"); + i2 = ne[2] - n; + } + LOG_DBG(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + LOG_DBG(" ..., \n"); + i1 = ne[1] - n; + } + LOG_DBG(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + LOG_DBG("..., "); + i0 = ne[0] - n; + } + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + LOG_DBG("%12.4f", v); + if (i0 < ne[0] - 1) { + LOG_DBG(", "); + } + } + LOG_DBG("],\n"); + } + LOG_DBG(" ],\n"); + } + LOG_DBG(" ]\n"); + LOG_DBG(" sum = %f\n", sum); + LOG_DBG(" sum_sq = %f\n", sum_sq); + } + + if (std::isnan(sum)) { + LOG_ERR("encountered NaN - aborting\n"); + exit(0); + } +} + +/** + * GGML operations callback during the graph execution. + * + * @param t current tensor + * @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor + * if we return true, a follow-up call will be made with ask=false in which we can do the actual collection. + * see ggml_backend_sched_eval_callback + * @param user_data user data to pass at each call back + * @return true to receive data or continue the graph, false otherwise + */ +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { + auto * cb_data = (callback_data *) user_data; + + const struct ggml_tensor * src0 = t->src[0]; + const struct ggml_tensor * src1 = t->src[1]; + + if (ask) { + return true; // Always retrieve data + } + + bool matches_filter = cb_data->tensor_filters.empty(); + + if (!matches_filter) { + for (const auto & filter : cb_data->tensor_filters) { + if (std::regex_search(t->name, filter)) { + matches_filter = true; + break; + } + } + } + + char src1_str[128] = {0}; + if (src1) { + snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); + } + + if (matches_filter) { + LOG_DBG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + t->name, + ggml_type_name(t->type), + ggml_op_desc(t), + src0->name, + ggml_ne_string(src0).c_str(), + src1 ? src1_str : "", + ggml_ne_string(t).c_str()); + } + + const bool is_host = ggml_backend_buffer_is_host(t->buffer); + + if (!is_host) { + auto n_bytes = ggml_nbytes(t); + cb_data->data.resize(n_bytes); + ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes); + } + + if (!ggml_is_quantized(t->type) && matches_filter) { + uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); + ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + } + + return true; +} + + +static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) { + std::filesystem::create_directory(output_dir); + auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix); + + // Save logits/embeddings to binary file. + { + std::filesystem::path filepath{base_path.string() + ".bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open binary output file: " + filepath.string()); + } + file.write(reinterpret_cast(output.data_ptr), output.data_size * sizeof(float)); + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save logits/embeddings to text file. + { + std::filesystem::path filepath{base_path.string() + ".txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open text output file: " + filepath.string()); + } + for (int i = 0; i < output.data_size; i++) { + file << i << ": " << output.data_ptr[i] << '\n'; + } + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save prompt and tokens to text file. + { + std::filesystem::path filepath{base_path.string() + "-prompt.txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open prompt output file: " + filepath.string()); + } + + file << "prompt: " << output.prompt << '\n'; + file << "n_tokens: " << output.tokens.size() << '\n'; + + file << "token ids: "; + for (size_t i = 0; i < output.tokens.size(); i++) { + file << output.tokens[i]; + if (i + 1 < output.tokens.size()) { + file << ", "; + } + } + file << '\n'; + LOG("Prompt saved to %s\n", filepath.c_str()); + } + + // Save token ids to binary file. + { + std::filesystem::path filepath{base_path.string() + "-tokens.bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open tokens binary file: " + filepath.string()); + } + file.write(reinterpret_cast(output.tokens.data()), output.tokens.size() * sizeof(llama_token)); + LOG("Tokens saved to %s\n", filepath.c_str()); + } + +} + +static void print_tokenized_prompt(llama_context * ctx, const std::vector & tokens, const std::string & prompt) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false"); + LOG("Input prompt: \"%s\"\n", prompt.c_str()); + LOG("Token ids (%zu):\n", tokens.size()); + + for (auto id : tokens) { + std::string piece(128, '\0'); + int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true); + if (n < 0) { + LOG_ERR("failed to convert token %d to piece\n", id); + continue; + } + piece.resize(n); + LOG("%s(%d) ", piece.c_str(), id); + } + LOG("\n"); +} + +static bool run(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + + std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); + + if (tokens.empty()) { + LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__); + return false; + } + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + + print_tokenized_prompt(ctx, tokens, params.prompt); + + if (params.save_logits) { + output_data output {ctx, model, params}; + std::filesystem::path model_path{params.model.path}; + std::string model_name{model_path.stem().string()}; + save_output_data(output, model_name, params.logits_output_dir); + } + + return true; +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + callback_data cb_data(params, params.tensor_filter); + + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + LOG_ERR("%s : failed to init\n", __func__); + return 1; + } + + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + } + + if (!run(ctx, params)) { + return 1; + } + + LOG("\n"); + llama_perf_context_print(ctx); + + llama_backend_free(); + + return 0; +} diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp deleted file mode 100644 index 5bcf063267..0000000000 --- a/examples/model-conversion/logits.cpp +++ /dev/null @@ -1,268 +0,0 @@ -#include "llama.h" -#include "common.h" - - -#include -#include -#include -#include -#include -#include - -static void print_usage(int, char ** argv) { - printf("\nexample usage:\n"); - printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm ] [prompt]\n", argv[0]); - printf("\n"); - printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n"); - printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n"); - printf("\n"); -} - -int main(int argc, char ** argv) { - std::string model_path; - std::string prompt = "Hello, my name is"; - int ngl = 0; - bool embedding_mode = false; - bool pooling_enabled = false; - int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) - - { - int i = 1; - for (; i < argc; i++) { - if (strcmp(argv[i], "-m") == 0) { - if (i + 1 < argc) { - model_path = argv[++i]; - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-ngl") == 0) { - if (i + 1 < argc) { - try { - ngl = std::stoi(argv[++i]); - } catch (...) { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-embd-mode") == 0) { - embedding_mode = true; - } else if (strcmp(argv[i], "-pooling") == 0) { - pooling_enabled = true; - } else if (strcmp(argv[i], "-embd-norm") == 0) { - if (i + 1 < argc) { - try { - embd_norm = std::stoi(argv[++i]); - } catch (...) { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } else { - // prompt starts here - break; - } - } - - if (model_path.empty()) { - print_usage(argc, argv); - return 1; - } - - if (i < argc) { - prompt = argv[i++]; - for (; i < argc; i++) { - prompt += " "; - prompt += argv[i]; - } - } - } - - ggml_backend_load_all(); - llama_model_params model_params = llama_model_default_params(); - model_params.n_gpu_layers = ngl; - - llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); - - if (model == NULL) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); - return 1; - } - - // Extract basename from model_path - const char * basename = strrchr(model_path.c_str(), '/'); - basename = (basename == NULL) ? model_path.c_str() : basename + 1; - - char model_name[256]; - strncpy(model_name, basename, 255); - model_name[255] = '\0'; - - char * dot = strrchr(model_name, '.'); - if (dot != NULL && strcmp(dot, ".gguf") == 0) { - *dot = '\0'; - } - printf("Model name: %s\n", model_name); - - const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); - - std::vector prompt_tokens(n_prompt); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { - fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); - return 1; - } - - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = n_prompt; - ctx_params.n_batch = n_prompt; - ctx_params.no_perf = false; - if (embedding_mode) { - ctx_params.embeddings = true; - ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE; - ctx_params.n_ubatch = ctx_params.n_batch; - } - - llama_context * ctx = llama_init_from_model(model, ctx_params); - if (ctx == NULL) { - fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); - return 1; - } - - printf("Input prompt: \"%s\"\n", prompt.c_str()); - printf("Tokenized prompt (%d tokens): ", n_prompt); - for (auto id : prompt_tokens) { - char buf[128]; - int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); - if (n < 0) { - fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); - return 1; - } - std::string s(buf, n); - printf("%s (%d)", s.c_str(), id); - } - printf("\n"); - - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); - - if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; - } - - float * data_ptr; - int data_size; - const char * type; - std::vector embd_out; - - if (embedding_mode) { - const int n_embd = llama_model_n_embd(model); - const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens; - const int n_embeddings = n_embd * n_embd_count; - float * embeddings; - type = "-embeddings"; - - if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) { - embeddings = llama_get_embeddings_seq(ctx, 0); - embd_out.resize(n_embeddings); - printf("Normalizing embeddings using norm: %d\n", embd_norm); - common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm); - embeddings = embd_out.data(); - } else { - embeddings = llama_get_embeddings(ctx); - } - - printf("Embedding dimension: %d\n", n_embd); - printf("\n"); - - // Print embeddings in the specified format - for (int j = 0; j < n_embd_count; j++) { - printf("embedding %d: ", j); - - // Print first 3 values - for (int i = 0; i < 3 && i < n_embd; i++) { - printf("%9.6f ", embeddings[j * n_embd + i]); - } - - printf(" ... "); - - // Print last 3 values - for (int i = n_embd - 3; i < n_embd; i++) { - if (i >= 0) { - printf("%9.6f ", embeddings[j * n_embd + i]); - } - } - - printf("\n"); - } - printf("\n"); - - printf("Embeddings size: %d\n", n_embeddings); - - data_ptr = embeddings; - data_size = n_embeddings; - } else { - float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - const int n_logits = llama_vocab_n_tokens(vocab); - type = ""; - printf("Vocab size: %d\n", n_logits); - - data_ptr = logits; - data_size = n_logits; - } - - std::filesystem::create_directory("data"); - - // Save data to binary file - char bin_filename[512]; - snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type); - printf("Saving data to %s\n", bin_filename); - - FILE * f = fopen(bin_filename, "wb"); - if (f == NULL) { - fprintf(stderr, "%s: error: failed to open binary output file\n", __func__); - return 1; - } - fwrite(data_ptr, sizeof(float), data_size, f); - fclose(f); - - // Also save as text for debugging - char txt_filename[512]; - snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type); - f = fopen(txt_filename, "w"); - if (f == NULL) { - fprintf(stderr, "%s: error: failed to open text output file\n", __func__); - return 1; - } - for (int i = 0; i < data_size; i++) { - fprintf(f, "%d: %.6f\n", i, data_ptr[i]); - } - fclose(f); - - if (!embedding_mode) { - printf("First 10 logits: "); - for (int i = 0; i < 10 && i < data_size; i++) { - printf("%.6f ", data_ptr[i]); - } - printf("\n"); - - printf("Last 10 logits: "); - for (int i = data_size - 10; i < data_size; i++) { - if (i >= 0) printf("%.6f ", data_ptr[i]); - } - printf("\n\n"); - } - - printf("Data saved to %s\n", bin_filename); - printf("Data saved to %s\n", txt_filename); - - llama_free(ctx); - llama_model_free(model); - - return 0; -} diff --git a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py index 55ad821385..4ab778fbc7 100755 --- a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py +++ b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py @@ -67,7 +67,7 @@ with torch.no_grad(): last_hidden_states = outputs.hidden_states[-1] # Get embeddings for all tokens - token_embeddings = last_hidden_states[0].cpu().numpy() # Remove batch dimension + token_embeddings = last_hidden_states[0].float().cpu().numpy() # Remove batch dimension print(f"Hidden states shape: {last_hidden_states.shape}") print(f"Token embeddings shape: {token_embeddings.shape}") diff --git a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh index fa16a02c65..3cce3fc94d 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh @@ -13,6 +13,6 @@ if [ -z "$CONVERTED_MODEL" ]; then exit 1 fi -cmake --build ../../build --target llama-logits -j8 +cmake --build ../../build --target llama-debug -j8 -../../build/bin/llama-logits -m $CONVERTED_MODEL -embd-mode "Hello world today" +../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index 529e9987b0..b6c3d38662 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -21,6 +21,6 @@ fi echo $CONVERTED_MODEL echo $MODEL_TESTING_PROMPT -cmake --build ../../build --target llama-logits -j8 +cmake --build ../../build --target llama-debug -j8 -../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT" +../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index 0f490e6c3b..5d264b0663 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -50,10 +50,9 @@ fi echo $CONVERTED_MODEL -cmake --build ../../build --target llama-logits -j8 -# TODO: update logits.cpp to accept a --file/-f option for the prompt +cmake --build ../../build --target llama-debug -j8 if [ -n "$USE_POOLING" ]; then - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT" + ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling mean -p "$PROMPT" --save-logits else - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" + ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling none -p "$PROMPT" --save-logits fi