From ffba4f29e6a9ed7165ea6b94150856c5b49925cb Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 7 Jan 2026 10:42:19 +0100 Subject: [PATCH 01/27] examples : add debug utility/example (#18464) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * examples : add debug utility/example This commit introduces a new example named llama-debug which is a utility that is intended to be used to assist with developing/debugging a converted model. The motivation for this utilitiy is to assist in model conversion work to verify that the model produces the expected outputs. It is intended to replace logits.cpp in examples/model-conversion. Example usage: ```console ./build/bin/llama-debug \ -m models/Qwen2.5-0.5B-Instruct.gguf \ --prompt "Hello, my name is" \ --save-logits ... Model add_bos: false Input prompt: "Hello, my name is" Token ids (5): Hello(9707) ,(11) my(847) name(829) is(374) Data saved to data/llamacpp-Qwen2.5-0.5B-Instruct.bin Data saved to data/llamacpp-Qwen2.5-0.5B-Instruct.txt Prompt saved to data/llamacpp-Qwen2.5-0.5B-Instruct-prompt.txt Tokens saved to data/llamacpp-Qwen2.5-0.5B-Instruct-tokens.bin ``` For more details about the options available for this example, please refer to examples/debug/README.md. * throw runtime error instead of logging error * remove params.warmup and enable the warmup/nowarmup option * model-conversion : remove logits.cpp This commit removes logits.cpp in favor of using llama-debug for generating logits and embeddings. * examples : remove model-conversion directory This was missed in the previous commit. * model-conversion : add support for saving prompt and token ids This commit add support for storing the prompt and the token ids for the prompt when running the original models. The motivation for this is that this will allow us to compare the prompt and the tokens generated for the prompt when verifing the converted model. Currently it is possible that even if the same prompt is used that the tokens generated are different if there is a difference in the tokenization between the original and converted model which would currently go unnoticed (the verification will most likely fail but it might not be obvious why). * squash! model-conversion : add support for saving prompt and token ids fix pyright errors. * model-conversion : add compare_tokens utility This commit adds a script to compare token outputs between original and converted models. Example usage: ```console (venv) $ ./scripts/utils/compare_tokens.py pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 Comparing tokens between: Original : pytorch-gemma-3-270m-it (6 tokens) Converted: llamacpp-gemma-3-270m-it-bf16 (6 tokens) āœ… All 6 tokens match! ``` And there is a verbose flag that will also print out the prompts: ```console (venv) $ ./scripts/utils/compare_tokens.py pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 -v Original model prompt (pytorch-gemma-3-270m-it): prompt: Hello, my name is n_tokens: 6 token ids: 2, 9259, 236764, 1041, 1463, 563 Converted model prompt (llamacpp-gemma-3-270m-it-bf16): prompt: Hello, my name is n_tokens: 6 token ids: 2, 9259, 236764, 1041, 1463, 563 Comparing tokens between: Original : pytorch-gemma-3-270m-it (6 tokens) Converted: llamacpp-gemma-3-270m-it-bf16 (6 tokens) āœ… All 6 tokens match! ``` * model-conversion : add token comparison to verifiction scripts This commit add the calling of the compare_tokens function in compare-logits.py and semantic_check.py to ensure that the token ids that the tokenizers procoduce are the same before proceeding with verifying the logits/embeddings. Placing them in the existing scripts instead calling them separately ensures that the token comparison is always done prior to the logit/embedding verifications. Follow up commit/pr could refactor the causal logits verification into a single script instead of the two that exist now. This would reduce the code and make it consistent with the embeddings verficiation which only has a single script. * debug : use llama_model_n_embd_out This commit updates the debug example to use the new function llama_model_n_embd_out instead of llama_model_n_embd. The motivation for this change is to support late interation retriever models, like LFM2-ColBert-350M, where the output embeddings are down projected to a lower dimension. * debug : add print_usage function This commit adds a print_usage function that is passed to the common_params_parse. The motivation for this is that this enables a specific usage message which will be printed after all the options, for example: ```console example usage: Print tensors: ./build/bin/llama-debug -m model.gguf -p "Hello my name is" --verbose The tensors to be printed can be filtered with --tensor-filter option. Save logits/embeddings: ./build/bin/llama-debug -m model.gguf -p "Hello my name is" --save-logits Add --embedding to save embeddings ``` --- common/arg.cpp | 29 +- common/common.h | 6 + examples/CMakeLists.txt | 2 +- .../CMakeLists.txt | 4 +- examples/debug/README.md | 54 +++ examples/debug/debug.cpp | 421 ++++++++++++++++++ examples/model-conversion/logits.cpp | 268 ----------- .../scripts/causal/compare-logits.py | 9 +- .../causal/run-casual-gen-embeddings-org.py | 2 +- .../run-converted-model-embeddings-logits.sh | 4 +- .../scripts/causal/run-converted-model.sh | 4 +- .../scripts/causal/run-org-model.py | 20 +- .../scripts/embedding/run-converted-model.sh | 7 +- .../scripts/embedding/run-original-model.py | 25 +- .../model-conversion/scripts/utils/common.py | 95 ++++ .../scripts/utils/compare_tokens.py | 76 ++++ .../scripts/utils/semantic_check.py | 18 + 17 files changed, 725 insertions(+), 319 deletions(-) rename examples/{model-conversion => debug}/CMakeLists.txt (73%) create mode 100644 examples/debug/README.md create mode 100644 examples/debug/debug.cpp delete mode 100644 examples/model-conversion/logits.cpp create mode 100755 examples/model-conversion/scripts/utils/compare_tokens.py diff --git a/common/arg.cpp b/common/arg.cpp index c3610d262b..a67a26e2dc 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1445,7 +1445,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.warmup = value; } - ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); + ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--spm-infill"}, string_format( @@ -1761,7 +1761,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING")); add_opt(common_arg( {"--attention"}, "{causal,non-causal}", "attention type for embeddings, use model default if unspecified", @@ -2609,7 +2609,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.embd_normalize = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", @@ -2687,7 +2687,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.embedding = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_EMBEDDINGS")); add_opt(common_arg( {"--rerank", "--reranking"}, string_format("enable reranking endpoint on server (default: %s)", "disabled"), @@ -3378,6 +3378,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_examples({ LLAMA_EXAMPLE_FINETUNE })); + add_opt(common_arg( + {"--save-logits"}, + string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"), + [](common_params & params) { + params.save_logits = true; + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); + add_opt(common_arg( + {"--logits-output-dir"}, "PATH", + string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()), + [](common_params & params, const std::string & value) { + params.logits_output_dir = value; + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); + add_opt(common_arg( + {"--tensor-filter"}, "REGEX", + "filter tensor names for debug output (regex pattern, can be specified multiple times)", + [](common_params & params, const std::string & value) { + params.tensor_filter.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); // presets add_opt(common_arg( diff --git a/common/common.h b/common/common.h index daea6ded5b..d6fd0d37a9 100644 --- a/common/common.h +++ b/common/common.h @@ -80,6 +80,7 @@ int32_t cpu_get_num_math(); // enum llama_example { + LLAMA_EXAMPLE_DEBUG, LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_COMPLETION, @@ -372,6 +373,11 @@ struct common_params { std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT + // llama-debug specific options + std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT + bool save_logits = false; // whether to save logits to files // NOLINT + std::vector tensor_filter; // filter tensor names for debug output (regex) // NOLINT + std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 91797cf78a..a29dc707c3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -15,6 +15,7 @@ llama_add_compile_flags() if (EMSCRIPTEN) else() add_subdirectory(batched) + add_subdirectory(debug) add_subdirectory(embedding) add_subdirectory(eval-callback) @@ -34,7 +35,6 @@ else() add_subdirectory(gen-docs) add_subdirectory(training) add_subdirectory(diffusion) - add_subdirectory(model-conversion) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) # these examples use the backends directly and cannot be built with dynamic loading diff --git a/examples/model-conversion/CMakeLists.txt b/examples/debug/CMakeLists.txt similarity index 73% rename from examples/model-conversion/CMakeLists.txt rename to examples/debug/CMakeLists.txt index fc1746ce45..34593072be 100644 --- a/examples/model-conversion/CMakeLists.txt +++ b/examples/debug/CMakeLists.txt @@ -1,5 +1,5 @@ -set(TARGET llama-logits) -add_executable(${TARGET} logits.cpp) +set(TARGET llama-debug) +add_executable(${TARGET} debug.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/debug/README.md b/examples/debug/README.md new file mode 100644 index 0000000000..28e00c9342 --- /dev/null +++ b/examples/debug/README.md @@ -0,0 +1,54 @@ +# llama.cpp/examples/debug + +This is a utility intended to help debug a model by registering a callback that +logs GGML operations and tensor data. It can also store the generated logits or +embeddings as well as the prompt and token ids for comparision with the original +model. + +### Usage + +```shell +llama-debug \ + --hf-repo ggml-org/models \ + --hf-file phi-2/ggml-model-q4_0.gguf \ + --model phi-2-q4_0.gguf \ + --prompt hello \ + --save-logits \ + --verbose +``` +The tensor data is logged as debug and required the --verbose flag. The reason +for this is that while useful for a model with many layers there can be a lot of +output. You can filter the tensor names using the `--tensor-filter` option. + +A recommended approach is to first run without `--verbose` and see if the +generated logits/embeddings are close to the original model. If they are not, +then it might be required to inspect tensor by tensor and in that case it is +useful to enable the `--verbose` flag along with `--tensor-filter` to focus on +specific tensors. + +### Options +This example supports all standard `llama.cpp` options and also accepts the +following options: +```console +$ llama-debug --help +... + +----- example-specific params ----- + +--save-logits save final logits to files for verification (default: false) +--logits-output-dir PATH directory for saving logits output files (default: data) +--tensor-filter REGEX filter tensor names for debug output (regex pattern, can be specified multiple times) +``` + +### Output Files + +When `--save-logits` is enabled, the following files are created in the output +directory: + +* `llamacpp-[-embeddings].bin` - Binary output (logits or embeddings) +* `llamacpp-[-embeddings].txt` - Text output (logits or embeddings, one per line) +* `llamacpp-[-embeddings]-prompt.txt` - Prompt text and token IDs +* `llamacpp-[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison + +These files can be compared against the original model's output to verify the +converted model. diff --git a/examples/debug/debug.cpp b/examples/debug/debug.cpp new file mode 100644 index 0000000000..9bc5d0abfd --- /dev/null +++ b/examples/debug/debug.cpp @@ -0,0 +1,421 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +static void print_usage(int, char ** argv) { + const std::string usage_template = R"( + example usage: + + Print tensors: + + {prog} -m model.gguf -p "Hello my name is" --verbose + + The tensors to be printed can be filtered with --tensor-filter option. + + Save logits/embeddings: + + {prog} -m model.gguf -p "Hello my name is" --save-logits + + Add --embedding to save embeddings)" "\n"; + + // Fix the source code indentation above that is introduced by the raw string literal. + std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n"); + usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]); + LOG("%s\n", usage.c_str()); +} + +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data); + +struct callback_data { + std::vector data; + std::vector tensor_filters; + + callback_data() = default; + + callback_data(common_params & params, const std::vector & filter_patterns) { + for (const auto & pattern : filter_patterns) { + try { + std::string anchored_pattern = "^" + pattern; + tensor_filters.emplace_back(anchored_pattern, std::regex::optimize); + } catch (const std::regex_error & e) { + throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); + } + } + params.cb_eval = ggml_debug; + params.cb_eval_user_data = this; + } +}; + +struct output_data { + float * data_ptr = nullptr; + int data_size = 0; + std::string type_suffix; + std::vector storage; + std::string prompt; + std::vector tokens; + + output_data(llama_context * ctx, const llama_model * model, const common_params & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); + + tokens = common_tokenize(ctx, params.prompt, add_bos); + prompt = params.prompt; + + if (params.embedding) { + const int n_embd = llama_model_n_embd_out(model); + const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE; + const int n_embd_count = pooling_enabled ? 1 : tokens.size(); + const int n_embeddings = n_embd * n_embd_count; + + float * embeddings; + if (pooling_enabled) { + embeddings = llama_get_embeddings_seq(ctx, 0); + storage.resize(n_embeddings); + common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize); + embeddings = storage.data(); + } else { + embeddings = llama_get_embeddings(ctx); + } + + data_ptr = embeddings; + data_size = n_embeddings; + type_suffix = "-embeddings"; + } else { + const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1); + const int n_logits = llama_vocab_n_tokens(vocab); + + data_ptr = const_cast(logits); + data_size = n_logits; + type_suffix = ""; + } + } +}; + +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + +static float ggml_get_float_value(const uint8_t * data, ggml_type type, + const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + switch (type) { + case GGML_TYPE_F16: + return ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]); + case GGML_TYPE_F32: + return *(const float *) &data[i]; + case GGML_TYPE_I64: + return (float) *(const int64_t *) &data[i]; + case GGML_TYPE_I32: + return (float) *(const int32_t *) &data[i]; + case GGML_TYPE_I16: + return (float) *(const int16_t *) &data[i]; + case GGML_TYPE_I8: + return (float) *(const int8_t *) &data[i]; + case GGML_TYPE_BF16: + return ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]); + default: + GGML_ABORT("fatal error"); + } +} + +static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + GGML_ASSERT(n > 0); + float sum = 0; + float sum_sq = 0.0; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + sum += v; + sum_sq += v * v; + } + } + } + } + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + LOG_DBG(" [\n"); + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + LOG_DBG(" ..., \n"); + i2 = ne[2] - n; + } + LOG_DBG(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + LOG_DBG(" ..., \n"); + i1 = ne[1] - n; + } + LOG_DBG(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + LOG_DBG("..., "); + i0 = ne[0] - n; + } + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + LOG_DBG("%12.4f", v); + if (i0 < ne[0] - 1) { + LOG_DBG(", "); + } + } + LOG_DBG("],\n"); + } + LOG_DBG(" ],\n"); + } + LOG_DBG(" ]\n"); + LOG_DBG(" sum = %f\n", sum); + LOG_DBG(" sum_sq = %f\n", sum_sq); + } + + if (std::isnan(sum)) { + LOG_ERR("encountered NaN - aborting\n"); + exit(0); + } +} + +/** + * GGML operations callback during the graph execution. + * + * @param t current tensor + * @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor + * if we return true, a follow-up call will be made with ask=false in which we can do the actual collection. + * see ggml_backend_sched_eval_callback + * @param user_data user data to pass at each call back + * @return true to receive data or continue the graph, false otherwise + */ +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { + auto * cb_data = (callback_data *) user_data; + + const struct ggml_tensor * src0 = t->src[0]; + const struct ggml_tensor * src1 = t->src[1]; + + if (ask) { + return true; // Always retrieve data + } + + bool matches_filter = cb_data->tensor_filters.empty(); + + if (!matches_filter) { + for (const auto & filter : cb_data->tensor_filters) { + if (std::regex_search(t->name, filter)) { + matches_filter = true; + break; + } + } + } + + char src1_str[128] = {0}; + if (src1) { + snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); + } + + if (matches_filter) { + LOG_DBG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + t->name, + ggml_type_name(t->type), + ggml_op_desc(t), + src0->name, + ggml_ne_string(src0).c_str(), + src1 ? src1_str : "", + ggml_ne_string(t).c_str()); + } + + const bool is_host = ggml_backend_buffer_is_host(t->buffer); + + if (!is_host) { + auto n_bytes = ggml_nbytes(t); + cb_data->data.resize(n_bytes); + ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes); + } + + if (!ggml_is_quantized(t->type) && matches_filter) { + uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); + ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + } + + return true; +} + + +static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) { + std::filesystem::create_directory(output_dir); + auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix); + + // Save logits/embeddings to binary file. + { + std::filesystem::path filepath{base_path.string() + ".bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open binary output file: " + filepath.string()); + } + file.write(reinterpret_cast(output.data_ptr), output.data_size * sizeof(float)); + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save logits/embeddings to text file. + { + std::filesystem::path filepath{base_path.string() + ".txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open text output file: " + filepath.string()); + } + for (int i = 0; i < output.data_size; i++) { + file << i << ": " << output.data_ptr[i] << '\n'; + } + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save prompt and tokens to text file. + { + std::filesystem::path filepath{base_path.string() + "-prompt.txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open prompt output file: " + filepath.string()); + } + + file << "prompt: " << output.prompt << '\n'; + file << "n_tokens: " << output.tokens.size() << '\n'; + + file << "token ids: "; + for (size_t i = 0; i < output.tokens.size(); i++) { + file << output.tokens[i]; + if (i + 1 < output.tokens.size()) { + file << ", "; + } + } + file << '\n'; + LOG("Prompt saved to %s\n", filepath.c_str()); + } + + // Save token ids to binary file. + { + std::filesystem::path filepath{base_path.string() + "-tokens.bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open tokens binary file: " + filepath.string()); + } + file.write(reinterpret_cast(output.tokens.data()), output.tokens.size() * sizeof(llama_token)); + LOG("Tokens saved to %s\n", filepath.c_str()); + } + +} + +static void print_tokenized_prompt(llama_context * ctx, const std::vector & tokens, const std::string & prompt) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false"); + LOG("Input prompt: \"%s\"\n", prompt.c_str()); + LOG("Token ids (%zu):\n", tokens.size()); + + for (auto id : tokens) { + std::string piece(128, '\0'); + int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true); + if (n < 0) { + LOG_ERR("failed to convert token %d to piece\n", id); + continue; + } + piece.resize(n); + LOG("%s(%d) ", piece.c_str(), id); + } + LOG("\n"); +} + +static bool run(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + + std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); + + if (tokens.empty()) { + LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__); + return false; + } + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + + print_tokenized_prompt(ctx, tokens, params.prompt); + + if (params.save_logits) { + output_data output {ctx, model, params}; + std::filesystem::path model_path{params.model.path}; + std::string model_name{model_path.stem().string()}; + save_output_data(output, model_name, params.logits_output_dir); + } + + return true; +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + callback_data cb_data(params, params.tensor_filter); + + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + LOG_ERR("%s : failed to init\n", __func__); + return 1; + } + + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + } + + if (!run(ctx, params)) { + return 1; + } + + LOG("\n"); + llama_perf_context_print(ctx); + + llama_backend_free(); + + return 0; +} diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp deleted file mode 100644 index f71f772ab1..0000000000 --- a/examples/model-conversion/logits.cpp +++ /dev/null @@ -1,268 +0,0 @@ -#include "llama.h" -#include "common.h" - - -#include -#include -#include -#include -#include -#include - -static void print_usage(int, char ** argv) { - printf("\nexample usage:\n"); - printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm ] [prompt]\n", argv[0]); - printf("\n"); - printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n"); - printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n"); - printf("\n"); -} - -int main(int argc, char ** argv) { - std::string model_path; - std::string prompt = "Hello, my name is"; - int ngl = 0; - bool embedding_mode = false; - bool pooling_enabled = false; - int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) - - { - int i = 1; - for (; i < argc; i++) { - if (strcmp(argv[i], "-m") == 0) { - if (i + 1 < argc) { - model_path = argv[++i]; - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-ngl") == 0) { - if (i + 1 < argc) { - try { - ngl = std::stoi(argv[++i]); - } catch (...) { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-embd-mode") == 0) { - embedding_mode = true; - } else if (strcmp(argv[i], "-pooling") == 0) { - pooling_enabled = true; - } else if (strcmp(argv[i], "-embd-norm") == 0) { - if (i + 1 < argc) { - try { - embd_norm = std::stoi(argv[++i]); - } catch (...) { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } else { - // prompt starts here - break; - } - } - - if (model_path.empty()) { - print_usage(argc, argv); - return 1; - } - - if (i < argc) { - prompt = argv[i++]; - for (; i < argc; i++) { - prompt += " "; - prompt += argv[i]; - } - } - } - - ggml_backend_load_all(); - llama_model_params model_params = llama_model_default_params(); - model_params.n_gpu_layers = ngl; - - llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); - - if (model == NULL) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); - return 1; - } - - // Extract basename from model_path - const char * basename = strrchr(model_path.c_str(), '/'); - basename = (basename == NULL) ? model_path.c_str() : basename + 1; - - char model_name[256]; - strncpy(model_name, basename, 255); - model_name[255] = '\0'; - - char * dot = strrchr(model_name, '.'); - if (dot != NULL && strcmp(dot, ".gguf") == 0) { - *dot = '\0'; - } - printf("Model name: %s\n", model_name); - - const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); - - std::vector prompt_tokens(n_prompt); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { - fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); - return 1; - } - - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = n_prompt; - ctx_params.n_batch = n_prompt; - ctx_params.no_perf = false; - if (embedding_mode) { - ctx_params.embeddings = true; - ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE; - ctx_params.n_ubatch = ctx_params.n_batch; - } - - llama_context * ctx = llama_init_from_model(model, ctx_params); - if (ctx == NULL) { - fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); - return 1; - } - - printf("Input prompt: \"%s\"\n", prompt.c_str()); - printf("Tokenized prompt (%d tokens): ", n_prompt); - for (auto id : prompt_tokens) { - char buf[128]; - int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); - if (n < 0) { - fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); - return 1; - } - std::string s(buf, n); - printf("%s (%d)", s.c_str(), id); - } - printf("\n"); - - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); - - if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; - } - - float * data_ptr; - int data_size; - const char * type; - std::vector embd_out; - - if (embedding_mode) { - const int n_embd_out = llama_model_n_embd_out(model); - const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens; - const int n_embeddings = n_embd_out * n_embd_count; - float * embeddings; - type = "-embeddings"; - - if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) { - embeddings = llama_get_embeddings_seq(ctx, 0); - embd_out.resize(n_embeddings); - printf("Normalizing embeddings using norm: %d\n", embd_norm); - common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm); - embeddings = embd_out.data(); - } else { - embeddings = llama_get_embeddings(ctx); - } - - printf("Embedding dimension: %d\n", n_embd_out); - printf("\n"); - - // Print embeddings in the specified format - for (int j = 0; j < n_embd_count; j++) { - printf("embedding %d: ", j); - - // Print first 3 values - for (int i = 0; i < 3 && i < n_embd_out; i++) { - printf("%9.6f ", embeddings[j * n_embd_out + i]); - } - - printf(" ... "); - - // Print last 3 values - for (int i = n_embd_out - 3; i < n_embd_out; i++) { - if (i >= 0) { - printf("%9.6f ", embeddings[j * n_embd_out + i]); - } - } - - printf("\n"); - } - printf("\n"); - - printf("Embeddings size: %d\n", n_embeddings); - - data_ptr = embeddings; - data_size = n_embeddings; - } else { - float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - const int n_logits = llama_vocab_n_tokens(vocab); - type = ""; - printf("Vocab size: %d\n", n_logits); - - data_ptr = logits; - data_size = n_logits; - } - - std::filesystem::create_directory("data"); - - // Save data to binary file - char bin_filename[512]; - snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type); - printf("Saving data to %s\n", bin_filename); - - FILE * f = fopen(bin_filename, "wb"); - if (f == NULL) { - fprintf(stderr, "%s: error: failed to open binary output file\n", __func__); - return 1; - } - fwrite(data_ptr, sizeof(float), data_size, f); - fclose(f); - - // Also save as text for debugging - char txt_filename[512]; - snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type); - f = fopen(txt_filename, "w"); - if (f == NULL) { - fprintf(stderr, "%s: error: failed to open text output file\n", __func__); - return 1; - } - for (int i = 0; i < data_size; i++) { - fprintf(f, "%d: %.6f\n", i, data_ptr[i]); - } - fclose(f); - - if (!embedding_mode) { - printf("First 10 logits: "); - for (int i = 0; i < 10 && i < data_size; i++) { - printf("%.6f ", data_ptr[i]); - } - printf("\n"); - - printf("Last 10 logits: "); - for (int i = data_size - 10; i < data_size; i++) { - if (i >= 0) printf("%.6f ", data_ptr[i]); - } - printf("\n\n"); - } - - printf("Data saved to %s\n", bin_filename); - printf("Data saved to %s\n", txt_filename); - - llama_free(ctx); - llama_model_free(model); - - return 0; -} diff --git a/examples/model-conversion/scripts/causal/compare-logits.py b/examples/model-conversion/scripts/causal/compare-logits.py index 894302c69e..1a933207d5 100755 --- a/examples/model-conversion/scripts/causal/compare-logits.py +++ b/examples/model-conversion/scripts/causal/compare-logits.py @@ -6,7 +6,7 @@ from pathlib import Path # Add utils directory to path for direct script execution sys.path.insert(0, str(Path(__file__).parent.parent / "utils")) -from common import get_model_name_from_env_path # type: ignore[import-not-found] +from common import get_model_name_from_env_path, compare_tokens # type: ignore[import-not-found] def quick_logits_check(pytorch_file, llamacpp_file): """Lightweight sanity check before NMSE""" @@ -58,6 +58,13 @@ def main(): print("Checked all required files were found. Proceeding...\n") + # Verify tokens as they are a prerequisite for logits comparison. + print("šŸ” Token Comparison Check") + print("=" * 40) + if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"): + print("\nāŒ Token mismatch detected") + sys.exit(1) + print() print("šŸ” GGML Model Validation for model ", model_name) print("=" * 40) diff --git a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py index 55ad821385..4ab778fbc7 100755 --- a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py +++ b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py @@ -67,7 +67,7 @@ with torch.no_grad(): last_hidden_states = outputs.hidden_states[-1] # Get embeddings for all tokens - token_embeddings = last_hidden_states[0].cpu().numpy() # Remove batch dimension + token_embeddings = last_hidden_states[0].float().cpu().numpy() # Remove batch dimension print(f"Hidden states shape: {last_hidden_states.shape}") print(f"Token embeddings shape: {token_embeddings.shape}") diff --git a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh index fa16a02c65..3cce3fc94d 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh @@ -13,6 +13,6 @@ if [ -z "$CONVERTED_MODEL" ]; then exit 1 fi -cmake --build ../../build --target llama-logits -j8 +cmake --build ../../build --target llama-debug -j8 -../../build/bin/llama-logits -m $CONVERTED_MODEL -embd-mode "Hello world today" +../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index 529e9987b0..b6c3d38662 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -21,6 +21,6 @@ fi echo $CONVERTED_MODEL echo $MODEL_TESTING_PROMPT -cmake --build ../../build --target llama-logits -j8 +cmake --build ../../build --target llama-debug -j8 -../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT" +../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index b12173a1fb..215f1a9ee0 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -7,12 +7,11 @@ import importlib import torch import numpy as np -from pathlib import Path from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig # Add parent directory to path for imports sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) -from utils.common import debug_hook +from utils.common import debug_hook, save_output_data def parse_arguments(): parser = argparse.ArgumentParser(description="Process model with specified path") @@ -126,6 +125,7 @@ def main(): device = next(model.parameters()).device prompt = get_prompt(args) input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + token_ids = input_ids[0].cpu().tolist() print(f"Input tokens: {input_ids}") print(f"Input text: {repr(prompt)}") @@ -151,19 +151,6 @@ def main(): print(f"Last token logits shape: {last_logits.shape}") print(f"Vocab size: {len(last_logits)}") - data_dir = Path("data") - data_dir.mkdir(exist_ok=True) - bin_filename = data_dir / f"pytorch-{model_name}.bin" - txt_filename = data_dir / f"pytorch-{model_name}.txt" - - # Save to file for comparison - last_logits.astype(np.float32).tofile(bin_filename) - - # Also save as text file for easy inspection - with open(txt_filename, "w") as f: - for i, logit in enumerate(last_logits): - f.write(f"{i}: {logit:.6f}\n") - # Print some sample logits for quick verification print(f"First 10 logits: {last_logits[:10]}") print(f"Last 10 logits: {last_logits[-10:]}") @@ -175,8 +162,7 @@ def main(): token = tokenizer.decode([idx]) print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}") - print(f"Saved bin logits to: {bin_filename}") - print(f"Saved txt logist to: {txt_filename}") + save_output_data(last_logits, token_ids, prompt, model_name) if __name__ == "__main__": main() diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index 0f490e6c3b..5d264b0663 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -50,10 +50,9 @@ fi echo $CONVERTED_MODEL -cmake --build ../../build --target llama-logits -j8 -# TODO: update logits.cpp to accept a --file/-f option for the prompt +cmake --build ../../build --target llama-debug -j8 if [ -n "$USE_POOLING" ]; then - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT" + ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling mean -p "$PROMPT" --save-logits else - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" + ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling none -p "$PROMPT" --save-logits fi diff --git a/examples/model-conversion/scripts/embedding/run-original-model.py b/examples/model-conversion/scripts/embedding/run-original-model.py index 774e5638f7..0802cbcf4a 100755 --- a/examples/model-conversion/scripts/embedding/run-original-model.py +++ b/examples/model-conversion/scripts/embedding/run-original-model.py @@ -3,13 +3,15 @@ import argparse import os import sys -import numpy as np import importlib -from pathlib import Path from transformers import AutoTokenizer, AutoConfig, AutoModel import torch +# Add parent directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from utils.common import save_output_data + def parse_arguments(): parser = argparse.ArgumentParser(description='Run original embedding model') @@ -169,6 +171,7 @@ def main(): return_tensors="pt" ) tokens = encoded['input_ids'][0] + token_ids = tokens.cpu().tolist() token_strings = tokenizer.convert_ids_to_tokens(tokens) for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): print(f"{token_id:6d} -> '{token_str}'") @@ -185,6 +188,7 @@ def main(): ) tokens = encoded['input_ids'][0] + token_ids = tokens.cpu().tolist() token_strings = tokenizer.convert_ids_to_tokens(tokens) for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): print(f"{token_id:6d} -> '{token_str}'") @@ -228,24 +232,11 @@ def main(): print() - data_dir = Path("data") - data_dir.mkdir(exist_ok=True) - bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin" - txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt" - flattened_embeddings = all_embeddings.flatten() - flattened_embeddings.astype(np.float32).tofile(bin_filename) - - with open(txt_filename, "w") as f: - idx = 0 - for j in range(n_embd_count): - for value in all_embeddings[j]: - f.write(f"{idx}: {value:.6f}\n") - idx += 1 print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings Ɨ {n_embd} dimensions)") print("") - print(f"Saved bin embeddings to: {bin_filename}") - print(f"Saved txt embeddings to: {txt_filename}") + + save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings") if __name__ == "__main__": diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py index 7595d0410e..71761127bb 100644 --- a/examples/model-conversion/scripts/utils/common.py +++ b/examples/model-conversion/scripts/utils/common.py @@ -3,6 +3,8 @@ import os import sys import torch +import numpy as np +from pathlib import Path def get_model_name_from_env_path(env_path_name): @@ -148,3 +150,96 @@ def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_ # Patch it setattr(module, function_name, debug_rope) print(f"RoPE debug patching applied to {model_module_path}.{function_name}") + + +def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"): + """ + Save output data (logits/embeddings), tokens, and prompt to files. + + Args: + data: numpy array of floats (logits or embeddings) + tokens: list or array of token IDs + prompt: string containing the input prompt + model_name: name of the model + type_suffix: optional suffix like "-embeddings" (default: "") + output_dir: directory to save files (default: "data") + + Creates the following files in output_dir: + - pytorch-{model_name}{type_suffix}.bin + - pytorch-{model_name}{type_suffix}.txt + - pytorch-{model_name}{type_suffix}-prompt.txt + - pytorch-{model_name}{type_suffix}-tokens.bin + """ + data_dir = Path(output_dir) + data_dir.mkdir(exist_ok=True) + base_path = data_dir / f"pytorch-{model_name}{type_suffix}" + + # Convert and flatten logits/embeddings + data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data) + data = data.flatten() if data.ndim > 1 else data + + # Save logits/embedding files + data.astype(np.float32).tofile(f"{base_path}.bin") + print(f"Data saved to {base_path}.bin") + + with open(f"{base_path}.txt", "w") as f: + f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data)) + print(f"Data saved to {base_path}.txt") + + # Convert and flatten tokens + tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens) + tokens = tokens.flatten() if tokens.ndim > 1 else tokens + + # Save token binary file + tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin") + print(f"Tokens saved to {base_path}-tokens.bin") + + # Save prompt file + with open(f"{base_path}-prompt.txt", "w") as f: + f.write(f"prompt: {prompt}\n") + f.write(f"n_tokens: {len(tokens)}\n") + f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n") + print(f"Prompt saved to {base_path}-prompt.txt") + + +def compare_tokens(original, converted, type_suffix="", output_dir="data"): + data_dir = Path(output_dir) + + # Read tokens from both models + tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin" + tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin" + + if not tokens1_file.exists(): + print(f"Error: Token file not found: {tokens1_file}") + return False + + if not tokens2_file.exists(): + print(f"Error: Token file not found: {tokens2_file}") + return False + + tokens1 = np.fromfile(tokens1_file, dtype=np.int32) + tokens2 = np.fromfile(tokens2_file, dtype=np.int32) + + print(f"\nComparing tokens between:") + print(f" Original : {original} ({len(tokens1)} tokens)") + print(f" Converted: {converted} ({len(tokens2)} tokens)") + + if len(tokens1) != len(tokens2): + print(f"\nāŒ Token count mismatch: {len(tokens1)} vs {len(tokens2)}") + return False + + if np.array_equal(tokens1, tokens2): + print(f"\nāœ… All {len(tokens1)} tokens match!") + return True + + mismatches = np.where(tokens1 != tokens2)[0] + print(f"\nāŒ Found {len(mismatches)} mismatched tokens:") + + num_to_show = min(len(mismatches), 10) + for idx in mismatches[:num_to_show]: + print(f" Position {idx}: {tokens1[idx]} vs {tokens2[idx]}") + + if len(mismatches) > num_to_show: + print(f" ... and {len(mismatches) - num_to_show} more mismatches") + + return False diff --git a/examples/model-conversion/scripts/utils/compare_tokens.py b/examples/model-conversion/scripts/utils/compare_tokens.py new file mode 100755 index 0000000000..a286cb5683 --- /dev/null +++ b/examples/model-conversion/scripts/utils/compare_tokens.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +import argparse +import sys +from common import compare_tokens # type: ignore + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description='Compare tokens between two models', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 + """ + ) + parser.add_argument( + 'original', + help='Original model name' + ) + parser.add_argument( + 'converted', + help='Converted model name' + ) + parser.add_argument( + '-s', '--suffix', + default='', + help='Type suffix (e.g., "-embeddings")' + ) + parser.add_argument( + '-d', '--data-dir', + default='data', + help='Directory containing token files (default: data)' + ) + parser.add_argument( + '-v', '--verbose', + action='store_true', + help='Print prompts from both models' + ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + + if args.verbose: + from pathlib import Path + data_dir = Path(args.data_dir) + + prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt" + prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt" + + if prompt1_file.exists(): + print(f"\nOriginal model prompt ({args.original}):") + print(f" {prompt1_file.read_text().strip()}") + + if prompt2_file.exists(): + print(f"\nConverted model prompt ({args.converted}):") + print(f" {prompt2_file.read_text().strip()}") + + print() + + result = compare_tokens( + args.original, + args.converted, + type_suffix=args.suffix, + output_dir=args.data_dir + ) + + # Enable the script to be used in shell scripts so that they can check + # the exit code for success/failure. + sys.exit(0 if result else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py index e64c000497..38b03ce4d2 100644 --- a/examples/model-conversion/scripts/utils/semantic_check.py +++ b/examples/model-conversion/scripts/utils/semantic_check.py @@ -4,8 +4,10 @@ import numpy as np import argparse import os import importlib +from pathlib import Path from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel +from common import compare_tokens # type: ignore[import-not-found] unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') @@ -157,9 +159,25 @@ def main(): else: prompt = args.prompt + python_emb_path = Path(args.python_embeddings) + cpp_emb_path = Path(args.cpp_embeddings) + + # Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name") + python_model_name = python_emb_path.stem.replace("-embeddings", "") + cpp_model_name = cpp_emb_path.stem.replace("-embeddings", "") + print("Semantic Similarity Test Between Python and llama.cpp Embedding Models") print("=" * 70) + # First verify tokens match before comparing embeddings + print("\nšŸ” Token Comparison Check") + print("=" * 70) + data_dir = python_emb_path.parent + if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)): + print("\nāŒ Token mismatch detected") + exit(1) + print() + # Single prompt detailed comparison print(f"\nTesting with prompt: '{prompt}'") From 8c77a04cc723909eab5d3bc3ae14c82f4db1afc7 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 7 Jan 2026 10:13:17 +0000 Subject: [PATCH 02/27] vulkan: more mul mat optimizations (#18533) * q4_k * q5_k * q2_k * q4_1 * q5_1 * better buf index --- .../vulkan-shaders/dequant_funcs.glsl | 3 +- .../vulkan-shaders/mul_mm_funcs.glsl | 82 ++++++++++--------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 3 files changed, 47 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 376944f1e2..7865a6bda7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -462,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) { #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { - return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); + const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); + return dm; } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 1a3531761a..ce7f2d699a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; const uint iqs = idx & 0x03; @@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; const uint iqs = idx & 0x03; - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); - const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; - const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + const vec2 dm = vec2(data_a_packed32[ib].dm); + const uint vui = data_a_packed32[ib].qs[iqs]; + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); @@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; - const uint ib = idx / 8; - const uint iqs = idx & 0x07; + const uint ib = idx / 4; + const uint iqs = idx & 0x03; - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint uint_qh = data_a_packed16[ib].qh; - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + const vec2 dm = vec2(data_a_packed32[ib].dm); + const uint uint_qh = data_a_packed32[ib].qh; + const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10); + const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10); + const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10); + const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10); - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + const uint vui = data_a_packed32[ib].qs[iqs]; + const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y; + const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15 const uint scalesi = iqs / 8; // 0..15 const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi])); + const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303)); const uint scales = data_a[ib].scales[scalesi]; const vec2 dm = vec2(data_a[ib].dm); - const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); + const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint n = iqs / 32; // 0,1,2,3 const uint b = (iqs % 32) / 16; // 0,1 @@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy); + const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F)); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m), - fma(d, q.y, m)); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint n = iqs / 32; // 0,1,2,3 const uint b = (iqs % 32) / 16; // 0,1 @@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F; - const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4; - const vec2 q = vec2(unpack8(qs | qh).xy); + const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F; + const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4; + const vec4 q = vec4(unpack8(qs | qh)); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m), - fma(d, q.y, m)); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 5b61ff9ca2..bbdbf9dcaa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; if (tname == "bf16") { From 03023296cf63f4177f51db9126b16b06f0e0af98 Mon Sep 17 00:00:00 2001 From: virajwad <84867530+virajwad@users.noreply.github.com> Date: Wed, 7 Jan 2026 02:59:47 -0800 Subject: [PATCH 03/27] vulkan: Warptile tuning for Intel Xe2/Xe3 (#18178) * modify warptile tuning for xe3 * intel vendor check w/ coopmat support * fix back formatting * fix formatting change 2 * move intel check to chip specific tuning part * Change to support both windows and linux * modify m_warptile to l_warptile for intel * modify warptile tuning for bf16 matmuls to fix regression (m_warptile to l_warptile) * Code style changes * Code style changes (2) * Code style changes (3) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3c13777b8a..1f255b705e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2996,6 +2996,10 @@ static void ggml_vk_load_shaders(vk_device& device) { if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) { + // Xe2/Xe3 with coopmat enabled - warptile performance tuning + l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -3678,6 +3682,11 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; + if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { + // Xe2/Xe3 - bf16 warptile performance tuning + l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; + } + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } @@ -5061,11 +5070,23 @@ static vk_device ggml_vk_get_device(size_t idx) { switch (device->vendor_id) { #ifndef GGML_VULKAN_RUN_TESTS case VK_VENDOR_ID_AMD: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; case VK_VENDOR_ID_INTEL: - device->mul_mat_l[i] = false; + if (!device->coopmat_support || device->architecture != INTEL_XE2) { + device->mul_mat_l[i] = false; + device->mul_mat_id_l[i] = false; + } else { + device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel + device->mul_mat_id_l[i] = true; + } device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; - device->mul_mat_id_l[i] = false; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; From ca4a8370bc1ebf267073cfa29067ebeff7ab8015 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 7 Jan 2026 05:03:32 -0600 Subject: [PATCH 04/27] vulkan: reject ops when a tensor is too large to allocate (#18646) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 29 +++++++++++++--------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1f255b705e..d68735a040 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -14305,6 +14305,19 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const } static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + // reject any tensors larger than the max buffer size + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) { + return false; + } + } + if (ggml_nbytes(op) > device->max_buffer_size) { + return false; + } + switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -14353,8 +14366,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_MUL_MAT_ID: { ggml_type src0_type = op->src[0]->type; - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); if (op->op == GGML_OP_MUL_MAT_ID) { if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { // If there's not enough shared memory for row_ids and the result tile, fallback to CPU @@ -14415,8 +14426,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } case GGML_OP_FLASH_ATTN_EXT: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; uint32_t HSK = op->src[1]->ne[0]; uint32_t HSV = op->src[2]->ne[0]; @@ -14638,8 +14647,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); // pipeline_argsort_large_f32 requires vulkan memory model. if (device->vulkan_memory_model) { return true; @@ -14652,8 +14659,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); // We could potentially support larger, using argsort to sort the // whole thing. Not clear if this is needed. uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1; @@ -14700,8 +14705,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_CUMSUM: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); } @@ -14709,9 +14712,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } case GGML_OP_SOLVE_TRI: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) { return false; } @@ -14776,9 +14776,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - const uint32_t SPLIT_H = 16; size_t stateC_size = SPLIT_H * d_state * sizeof(float); From 9dfa8ee950b077b2d8a49caaa144dcc6bbc55305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Wed, 7 Jan 2026 13:07:08 +0100 Subject: [PATCH 05/27] ci : run cann build unconditionally [no ci] (#18659) --- .github/workflows/build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1193779d0b..85601b3712 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1418,7 +1418,6 @@ jobs: echo "FIXME: test on devices" openEuler-latest-cmake-cann: - if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} defaults: run: shell: bash -el {0} From bb77764c2d024a6fecc5bdeb3618cb580ee15041 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 7 Jan 2026 13:18:53 +0100 Subject: [PATCH 06/27] convert : clarify sentence-transformers-dense-modules help [no ci] (#18662) * convert : clarify sentence-transformers-dense-modules help [no ci] This commit updates this options help message which currently looks like this: ```console --sentence-transformers-dense-modules Whether to include sentence-transformers dense modules.It can be used for sentence-transformers models, like google/embeddinggemma-300mDefault these modules are not included. ``` --- convert_hf_to_gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d9ee390b38..0a8bac0e2d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10974,8 +10974,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--sentence-transformers-dense-modules", action="store_true", - help=("Whether to include sentence-transformers dense modules." - "It can be used for sentence-transformers models, like google/embeddinggemma-300m" + help=("Whether to include sentence-transformers dense modules. " + "It can be used for sentence-transformers models, like google/embeddinggemma-300m. " "Default these modules are not included.") ) From 56426673cb950feaff28c466c7cf38ac4c165742 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 7 Jan 2026 15:16:20 +0200 Subject: [PATCH 07/27] scripts : add pr2wt.sh (#18644) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * scripts : add pr2wt.sh * script : shebang Co-authored-by: SigbjĆørn SkjƦret --------- Co-authored-by: SigbjĆørn SkjƦret --- .gitignore | 1 + scripts/pr2wt.sh | 65 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100755 scripts/pr2wt.sh diff --git a/.gitignore b/.gitignore index 05eb578a82..bb122d6924 100644 --- a/.gitignore +++ b/.gitignore @@ -130,6 +130,7 @@ poetry.toml # Local scripts /run-vim.sh /run-chat.sh +/run-spec.sh /.ccache/ # IDE diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh new file mode 100755 index 0000000000..22251339ac --- /dev/null +++ b/scripts/pr2wt.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# intialize a new worktree from a PR number: +# +# - creates a new remote using the fork's clone URL +# - creates a local branch tracking the remote branch +# - creates a new worktree in a parent folder, suffixed with "-pr-${PR}" +# +# sample usage: +# ./scripts/pr2wt.sh 12345 +# ./scripts/pr2wt.sh 12345 opencode + +function usage() { + echo "usage: $0 [cmd]" + exit 1 +} + +# check we are in the right directory +if [[ ! -f "scripts/pr2wt.sh" ]]; then + echo "error: this script must be run from the root of the repository" + exit 1 +fi + +if [[ $# -lt 1 || $# -gt 2 ]]; then + usage +fi + +PR=$1 +[[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; } + +url_origin=$(git config --get remote.origin.url) || { + echo "error: no remote named 'origin' in this repository" + exit 1 +} + +org_repo=$(echo $url_origin | cut -d/ -f4-) + +echo "org/repo: $org_repo" + +meta=$(curl -sSf -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${org_repo}/pulls/${PR}") + +url_remote=$(echo "$meta" | jq -r '.head.repo.clone_url') +head_ref=$(echo "$meta" | jq -r '.head.ref') + +echo "url: $url_remote" +echo "head_ref: $head_ref" + +git remote rm pr/${PR} +git remote add pr/${PR} $url_remote +git fetch pr/${PR} $head_ref + +dir=$(basename $(pwd)) + +git branch -D pr/$PR 2> /dev/null +git worktree add -b pr/$PR ../$dir-pr-$PR pr/$PR/${head_ref} 2> /dev/null + +wt_path=$(cd ../$dir-pr-$PR && pwd) + +echo "git worktree created in $wt_path" + +# if a command was provided, execute it +if [[ $# -eq 2 ]]; then + cd ../$dir-pr-$PR + exec $2 +fi From 56d2fed2b3970ae55eebd0e5426d402304b1358a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Wed, 7 Jan 2026 16:18:26 +0100 Subject: [PATCH 08/27] tools : remove llama-run (#18661) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tools : remove llama-run * Remove licenses/LICENSE-linenoise Signed-off-by: Adrien GallouĆ«t --- README.md | 16 - common/arg.cpp | 1 - licenses/LICENSE-linenoise | 26 - tools/CMakeLists.txt | 1 - tools/run/CMakeLists.txt | 23 - tools/run/README.md | 52 - tools/run/linenoise.cpp/linenoise.cpp | 1995 ------------------------- tools/run/linenoise.cpp/linenoise.h | 137 -- tools/run/run.cpp | 1408 ----------------- 9 files changed, 3659 deletions(-) delete mode 100644 licenses/LICENSE-linenoise delete mode 100644 tools/run/CMakeLists.txt delete mode 100644 tools/run/README.md delete mode 100644 tools/run/linenoise.cpp/linenoise.cpp delete mode 100644 tools/run/linenoise.cpp/linenoise.h delete mode 100644 tools/run/run.cpp diff --git a/README.md b/README.md index ed956bb02e..e59612f7ae 100644 --- a/README.md +++ b/README.md @@ -482,21 +482,6 @@ To learn more about model quantization, [read this documentation](tools/quantize -## [`llama-run`](tools/run) - -#### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3]. - --
- Run a model with a specific prompt (by default it's pulled from Ollama registry) - - ```bash - llama-run granite-code - ``` - -
- -[^3]: [RamaLama](https://github.com/containers/ramalama) - ## [`llama-simple`](examples/simple) #### A minimal example for implementing apps with `llama.cpp`. Useful for developers. @@ -600,7 +585,6 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc - [stb-image](https://github.com/nothings/stb) - Single-header image format decoder, used by multimodal subsystem - Public domain - [nlohmann/json](https://github.com/nlohmann/json) - Single-header JSON library, used by various tools/examples - MIT License - [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License -- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License - [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html) - [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain - [subprocess.h](https://github.com/sheredom/subprocess.h) - Single-header process launching solution for C and C++ - Public domain diff --git a/common/arg.cpp b/common/arg.cpp index a67a26e2dc..e7966d9d5c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -679,7 +679,6 @@ static void common_params_print_completion(common_params_context & ctx_arg) { "llama-quantize", "llama-qwen2vl-cli", "llama-retrieval", - "llama-run", "llama-save-load-state", "llama-server", "llama-simple", diff --git a/licenses/LICENSE-linenoise b/licenses/LICENSE-linenoise deleted file mode 100644 index b006b3b24d..0000000000 --- a/licenses/LICENSE-linenoise +++ /dev/null @@ -1,26 +0,0 @@ -Copyright (c) 2010-2014, Salvatore Sanfilippo -Copyright (c) 2010-2013, Pieter Noordhuis -Copyright (c) 2025, Eric Curtin - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 8df3f41003..48959fefb5 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -25,7 +25,6 @@ else() if (LLAMA_BUILD_SERVER) add_subdirectory(server) endif() - add_subdirectory(run) add_subdirectory(tokenize) add_subdirectory(tts) add_subdirectory(mtmd) diff --git a/tools/run/CMakeLists.txt b/tools/run/CMakeLists.txt deleted file mode 100644 index 6ad7534e29..0000000000 --- a/tools/run/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -set(TARGET llama-run) -add_executable(${TARGET} run.cpp linenoise.cpp/linenoise.cpp) - -# TODO: avoid copying this code block from common/CMakeLists.txt -set(LLAMA_RUN_EXTRA_LIBS "") -if (LLAMA_CURL) - find_package(CURL REQUIRED) - target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) - include_directories(${CURL_INCLUDE_DIRS}) - set(LLAMA_RUN_EXTRA_LIBS ${LLAMA_RUN_EXTRA_LIBS} ${CURL_LIBRARIES}) -endif () - -if(LLAMA_TOOLS_INSTALL) - install(TARGETS ${TARGET} RUNTIME) -endif() - -if (CMAKE_SYSTEM_NAME MATCHES "AIX") - # AIX's flock() function comes from libbsd.a - target_link_libraries(${TARGET} PRIVATE -lbsd) -endif() - -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} ${LLAMA_RUN_EXTRA_LIBS}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/tools/run/README.md b/tools/run/README.md deleted file mode 100644 index 5fd769b44c..0000000000 --- a/tools/run/README.md +++ /dev/null @@ -1,52 +0,0 @@ -# llama.cpp/example/run - -The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. - -```bash -llama-run granite3-moe -``` - -```bash -Description: - Runs a llm - -Usage: - llama-run [options] model [prompt] - -Options: - -c, --context-size - Context size (default: 2048) - -n, -ngl, --ngl - Number of GPU layers (default: 0) - --temp - Temperature (default: 0.8) - -v, --verbose, --log-verbose - Set verbosity level to infinity (i.e. log all messages, useful for debugging) - -h, --help - Show help message - -Commands: - model - Model is a string with an optional prefix of - huggingface:// (hf://), ollama://, https:// or file://. - If no protocol is specified and a file exists in the specified - path, file:// is assumed, otherwise if a file does not exist in - the specified path, ollama:// is assumed. Models that are being - pulled are downloaded with .partial extension while being - downloaded and then renamed as the file without the .partial - extension when complete. - -Examples: - llama-run llama3 - llama-run ollama://granite-code - llama-run ollama://smollm:135m - llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf - llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf - llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf - llama-run modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf - llama-run https://example.com/some-file1.gguf - llama-run some-file2.gguf - llama-run file://some-file3.gguf - llama-run --ngl 999 some-file4.gguf - llama-run --ngl 999 some-file5.gguf Hello World -``` diff --git a/tools/run/linenoise.cpp/linenoise.cpp b/tools/run/linenoise.cpp/linenoise.cpp deleted file mode 100644 index 9cb9399003..0000000000 --- a/tools/run/linenoise.cpp/linenoise.cpp +++ /dev/null @@ -1,1995 +0,0 @@ -#ifndef _WIN32 -/* - * You can find the latest source code at: - * - * http://github.com/ericcurtin/linenoise.cpp - * - * Does a number of crazy assumptions that happen to be true in 99.9999% of - * the 2010 UNIX computers around. - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2010-2023, Salvatore Sanfilippo - * Copyright (c) 2010-2013, Pieter Noordhuis - * Copyright (c) 2025, Eric Curtin - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * ------------------------------------------------------------------------ - * - * References: - * - http://invisible-island.net/xterm/ctlseqs/ctlseqs.html - * - http://www.3waylabs.com/nw/WWW/products/wizcon/vt220.html - * - * Todo list: - * - Filter bogus Ctrl+ combinations. - * - Win32 support - * - * Bloat: - * - History search like Ctrl+r in readline? - * - * List of escape sequences used by this program, we do everything just - * with three sequences. In order to be so cheap we may have some - * flickering effect with some slow terminal, but the lesser sequences - * the more compatible. - * - * EL (Erase Line) - * Sequence: ESC [ n K - * Effect: if n is 0 or missing, clear from cursor to end of line - * Effect: if n is 1, clear from beginning of line to cursor - * Effect: if n is 2, clear entire line - * - * CUF (CUrsor Forward) - * Sequence: ESC [ n C - * Effect: moves cursor forward n chars - * - * CUB (CUrsor Backward) - * Sequence: ESC [ n D - * Effect: moves cursor backward n chars - * - * The following is used to get the terminal width if getting - * the width with the TIOCGWINSZ ioctl fails - * - * DSR (Device Status Report) - * Sequence: ESC [ 6 n - * Effect: reports the current cursor position as ESC [ n ; m R - * where n is the row and m is the column - * - * When multi line mode is enabled, we also use an additional escape - * sequence. However multi line editing is disabled by default. - * - * CUU (Cursor Up) - * Sequence: ESC [ n A - * Effect: moves cursor up of n chars. - * - * CUD (Cursor Down) - * Sequence: ESC [ n B - * Effect: moves cursor down of n chars. - * - * When linenoiseClearScreen() is called, two additional escape sequences - * are used in order to clear the screen and position the cursor at home - * position. - * - * CUP (Cursor position) - * Sequence: ESC [ H - * Effect: moves the cursor to upper left corner - * - * ED (Erase display) - * Sequence: ESC [ 2 J - * Effect: clear the whole screen - * - */ - -# include "linenoise.h" - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# include -# include -# include - -# define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 -# define LINENOISE_MAX_LINE 4096 -static std::vector unsupported_term = { "dumb", "cons25", "emacs" }; -static linenoiseCompletionCallback *completionCallback = NULL; -static linenoiseHintsCallback *hintsCallback = NULL; -static linenoiseFreeHintsCallback *freeHintsCallback = NULL; -static char *linenoiseNoTTY(void); -static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags); -static void refreshLineWithFlags(struct linenoiseState *l, int flags); - -static struct termios orig_termios; /* In order to restore at exit.*/ -static int maskmode = 0; /* Show "***" instead of input. For passwords. */ -static int rawmode = 0; /* For atexit() function to check if restore is needed*/ -static int mlmode = 0; /* Multi line mode. Default is single line. */ -static int atexit_registered = 0; /* Register atexit just 1 time. */ -static int history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; -static int history_len = 0; -static char **history = NULL; - -enum KEY_ACTION{ - KEY_NULL = 0, /* NULL */ - CTRL_A = 1, /* Ctrl+a */ - CTRL_B = 2, /* Ctrl-b */ - CTRL_C = 3, /* Ctrl-c */ - CTRL_D = 4, /* Ctrl-d */ - CTRL_E = 5, /* Ctrl-e */ - CTRL_F = 6, /* Ctrl-f */ - CTRL_H = 8, /* Ctrl-h */ - TAB = 9, /* Tab */ - CTRL_K = 11, /* Ctrl+k */ - CTRL_L = 12, /* Ctrl+l */ - ENTER = 13, /* Enter */ - CTRL_N = 14, /* Ctrl-n */ - CTRL_P = 16, /* Ctrl-p */ - CTRL_T = 20, /* Ctrl-t */ - CTRL_U = 21, /* Ctrl+u */ - CTRL_W = 23, /* Ctrl+w */ - ESC = 27, /* Escape */ - BACKSPACE = 127 /* Backspace */ -}; - -static void linenoiseAtExit(void); -int linenoiseHistoryAdd(const char *line); -#define REFRESH_CLEAN (1<<0) // Clean the old prompt from the screen -#define REFRESH_WRITE (1<<1) // Rewrite the prompt on the screen. -#define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both. -static void refreshLine(struct linenoiseState *l); - -class File { - public: - FILE * file = nullptr; - - FILE * open(const std::string & filename, const char * mode) { - file = fopen(filename.c_str(), mode); - - return file; - } - - int lock() { - if (file) { - fd = fileno(file); - if (flock(fd, LOCK_EX | LOCK_NB) != 0) { - fd = -1; - - return 1; - } - } - - return 0; - } - - ~File() { - if (fd >= 0) { - flock(fd, LOCK_UN); - } - - if (file) { - fclose(file); - } - } - - private: - int fd = -1; -}; - -#if 0 -/* Debugging function. */ -__attribute__((format(printf, 1, 2))) -static void lndebug(const char *fmt, ...) { - static File file; - if (file.file == nullptr) { - file.open("/tmp/lndebug.txt", "a"); - } - - if (file.file != nullptr) { - va_list args; - va_start(args, fmt); - vfprintf(file.file, fmt, args); - va_end(args); - fflush(file.file); - } -} -#endif - -/* ========================== Encoding functions ============================= */ - -/* Get length of previous UTF8 codepoint */ -static size_t prevUtf8CodePointLen(const char * buf, int pos) { - int end = pos--; - while (pos >= 0 && ((unsigned char) buf[pos] & 0xC0) == 0x80) { - pos--; - } - return end - pos; -} - -/* Convert UTF8 to Unicode code point */ -static size_t utf8BytesToCodePoint(const char * buf, size_t len, int * cp) { - if (len) { - unsigned char byte = buf[0]; - if ((byte & 0x80) == 0) { - *cp = byte; - return 1; - } else if ((byte & 0xE0) == 0xC0) { - if (len >= 2) { - *cp = (((unsigned long) (buf[0] & 0x1F)) << 6) | ((unsigned long) (buf[1] & 0x3F)); - return 2; - } - } else if ((byte & 0xF0) == 0xE0) { - if (len >= 3) { - *cp = (((unsigned long) (buf[0] & 0x0F)) << 12) | (((unsigned long) (buf[1] & 0x3F)) << 6) | - ((unsigned long) (buf[2] & 0x3F)); - return 3; - } - } else if ((byte & 0xF8) == 0xF0) { - if (len >= 4) { - *cp = (((unsigned long) (buf[0] & 0x07)) << 18) | (((unsigned long) (buf[1] & 0x3F)) << 12) | - (((unsigned long) (buf[2] & 0x3F)) << 6) | ((unsigned long) (buf[3] & 0x3F)); - return 4; - } - } - } - return 0; -} - -/* Check if the code is a wide character */ -static const unsigned long wideCharTable[][2] = { - /* BEGIN: WIDE CHAR TABLE */ - { 0x1100, 0x115F }, - { 0x231A, 0x231B }, - { 0x2329, 0x232A }, - { 0x23E9, 0x23EC }, - { 0x23F0, 0x23F0 }, - { 0x23F3, 0x23F3 }, - { 0x25FD, 0x25FE }, - { 0x2614, 0x2615 }, - { 0x2630, 0x2637 }, - { 0x2648, 0x2653 }, - { 0x267F, 0x267F }, - { 0x268A, 0x268F }, - { 0x2693, 0x2693 }, - { 0x26A1, 0x26A1 }, - { 0x26AA, 0x26AB }, - { 0x26BD, 0x26BE }, - { 0x26C4, 0x26C5 }, - { 0x26CE, 0x26CE }, - { 0x26D4, 0x26D4 }, - { 0x26EA, 0x26EA }, - { 0x26F2, 0x26F3 }, - { 0x26F5, 0x26F5 }, - { 0x26FA, 0x26FA }, - { 0x26FD, 0x26FD }, - { 0x2705, 0x2705 }, - { 0x270A, 0x270B }, - { 0x2728, 0x2728 }, - { 0x274C, 0x274C }, - { 0x274E, 0x274E }, - { 0x2753, 0x2755 }, - { 0x2757, 0x2757 }, - { 0x2795, 0x2797 }, - { 0x27B0, 0x27B0 }, - { 0x27BF, 0x27BF }, - { 0x2B1B, 0x2B1C }, - { 0x2B50, 0x2B50 }, - { 0x2B55, 0x2B55 }, - { 0x2E80, 0x2E99 }, - { 0x2E9B, 0x2EF3 }, - { 0x2F00, 0x2FD5 }, - { 0x2FF0, 0x303E }, - { 0x3041, 0x3096 }, - { 0x3099, 0x30FF }, - { 0x3105, 0x312F }, - { 0x3131, 0x318E }, - { 0x3190, 0x31E5 }, - { 0x31EF, 0x321E }, - { 0x3220, 0x3247 }, - { 0x3250, 0xA48C }, - { 0xA490, 0xA4C6 }, - { 0xA960, 0xA97C }, - { 0xAC00, 0xD7A3 }, - { 0xF900, 0xFAFF }, - { 0xFE10, 0xFE19 }, - { 0xFE30, 0xFE52 }, - { 0xFE54, 0xFE66 }, - { 0xFE68, 0xFE6B }, - { 0xFF01, 0xFF60 }, - { 0xFFE0, 0xFFE6 }, - { 0x16FE0, 0x16FE4 }, - { 0x16FF0, 0x16FF1 }, - { 0x17000, 0x187F7 }, - { 0x18800, 0x18CD5 }, - { 0x18CFF, 0x18D08 }, - { 0x1AFF0, 0x1AFF3 }, - { 0x1AFF5, 0x1AFFB }, - { 0x1AFFD, 0x1AFFE }, - { 0x1B000, 0x1B122 }, - { 0x1B132, 0x1B132 }, - { 0x1B150, 0x1B152 }, - { 0x1B155, 0x1B155 }, - { 0x1B164, 0x1B167 }, - { 0x1B170, 0x1B2FB }, - { 0x1D300, 0x1D356 }, - { 0x1D360, 0x1D376 }, - { 0x1F004, 0x1F004 }, - { 0x1F0CF, 0x1F0CF }, - { 0x1F18E, 0x1F18E }, - { 0x1F191, 0x1F19A }, - { 0x1F200, 0x1F202 }, - { 0x1F210, 0x1F23B }, - { 0x1F240, 0x1F248 }, - { 0x1F250, 0x1F251 }, - { 0x1F260, 0x1F265 }, - { 0x1F300, 0x1F320 }, - { 0x1F32D, 0x1F335 }, - { 0x1F337, 0x1F37C }, - { 0x1F37E, 0x1F393 }, - { 0x1F3A0, 0x1F3CA }, - { 0x1F3CF, 0x1F3D3 }, - { 0x1F3E0, 0x1F3F0 }, - { 0x1F3F4, 0x1F3F4 }, - { 0x1F3F8, 0x1F43E }, - { 0x1F440, 0x1F440 }, - { 0x1F442, 0x1F4FC }, - { 0x1F4FF, 0x1F53D }, - { 0x1F54B, 0x1F54E }, - { 0x1F550, 0x1F567 }, - { 0x1F57A, 0x1F57A }, - { 0x1F595, 0x1F596 }, - { 0x1F5A4, 0x1F5A4 }, - { 0x1F5FB, 0x1F64F }, - { 0x1F680, 0x1F6C5 }, - { 0x1F6CC, 0x1F6CC }, - { 0x1F6D0, 0x1F6D2 }, - { 0x1F6D5, 0x1F6D7 }, - { 0x1F6DC, 0x1F6DF }, - { 0x1F6EB, 0x1F6EC }, - { 0x1F6F4, 0x1F6FC }, - { 0x1F7E0, 0x1F7EB }, - { 0x1F7F0, 0x1F7F0 }, - { 0x1F90C, 0x1F93A }, - { 0x1F93C, 0x1F945 }, - { 0x1F947, 0x1F9FF }, - { 0x1FA70, 0x1FA7C }, - { 0x1FA80, 0x1FA89 }, - { 0x1FA8F, 0x1FAC6 }, - { 0x1FACE, 0x1FADC }, - { 0x1FADF, 0x1FAE9 }, - { 0x1FAF0, 0x1FAF8 }, - { 0x20000, 0x2FFFD }, - { 0x30000, 0x3FFFD } - /* END: WIDE CHAR TABLE */ -}; - -static const size_t wideCharTableSize = sizeof(wideCharTable) / sizeof(wideCharTable[0]); - -static bool isWideChar(unsigned long cp) { - for (size_t i = 0; i < wideCharTableSize; i++) { - auto first_code = wideCharTable[i][0]; - auto last_code = wideCharTable[i][1]; - if (first_code > cp) { - return false; - } - if (first_code <= cp && cp <= last_code) { - return true; - } - } - return false; -} - -/* Check if the code is a combining character */ -static const unsigned long combiningCharTable[] = { - /* BEGIN: COMBINING CHAR TABLE */ - 0x0300, 0x0301, 0x0302, 0x0303, 0x0304, 0x0305, 0x0306, 0x0307, 0x0308, 0x0309, 0x030A, 0x030B, 0x030C, - 0x030D, 0x030E, 0x030F, 0x0310, 0x0311, 0x0312, 0x0313, 0x0314, 0x0315, 0x0316, 0x0317, 0x0318, 0x0319, - 0x031A, 0x031B, 0x031C, 0x031D, 0x031E, 0x031F, 0x0320, 0x0321, 0x0322, 0x0323, 0x0324, 0x0325, 0x0326, - 0x0327, 0x0328, 0x0329, 0x032A, 0x032B, 0x032C, 0x032D, 0x032E, 0x032F, 0x0330, 0x0331, 0x0332, 0x0333, - 0x0334, 0x0335, 0x0336, 0x0337, 0x0338, 0x0339, 0x033A, 0x033B, 0x033C, 0x033D, 0x033E, 0x033F, 0x0340, - 0x0341, 0x0342, 0x0343, 0x0344, 0x0345, 0x0346, 0x0347, 0x0348, 0x0349, 0x034A, 0x034B, 0x034C, 0x034D, - 0x034E, 0x034F, 0x0350, 0x0351, 0x0352, 0x0353, 0x0354, 0x0355, 0x0356, 0x0357, 0x0358, 0x0359, 0x035A, - 0x035B, 0x035C, 0x035D, 0x035E, 0x035F, 0x0360, 0x0361, 0x0362, 0x0363, 0x0364, 0x0365, 0x0366, 0x0367, - 0x0368, 0x0369, 0x036A, 0x036B, 0x036C, 0x036D, 0x036E, 0x036F, 0x0483, 0x0484, 0x0485, 0x0486, 0x0487, - 0x0591, 0x0592, 0x0593, 0x0594, 0x0595, 0x0596, 0x0597, 0x0598, 0x0599, 0x059A, 0x059B, 0x059C, 0x059D, - 0x059E, 0x059F, 0x05A0, 0x05A1, 0x05A2, 0x05A3, 0x05A4, 0x05A5, 0x05A6, 0x05A7, 0x05A8, 0x05A9, 0x05AA, - 0x05AB, 0x05AC, 0x05AD, 0x05AE, 0x05AF, 0x05B0, 0x05B1, 0x05B2, 0x05B3, 0x05B4, 0x05B5, 0x05B6, 0x05B7, - 0x05B8, 0x05B9, 0x05BA, 0x05BB, 0x05BC, 0x05BD, 0x05BF, 0x05C1, 0x05C2, 0x05C4, 0x05C5, 0x05C7, 0x0610, - 0x0611, 0x0612, 0x0613, 0x0614, 0x0615, 0x0616, 0x0617, 0x0618, 0x0619, 0x061A, 0x064B, 0x064C, 0x064D, - 0x064E, 0x064F, 0x0650, 0x0651, 0x0652, 0x0653, 0x0654, 0x0655, 0x0656, 0x0657, 0x0658, 0x0659, 0x065A, - 0x065B, 0x065C, 0x065D, 0x065E, 0x065F, 0x0670, 0x06D6, 0x06D7, 0x06D8, 0x06D9, 0x06DA, 0x06DB, 0x06DC, - 0x06DF, 0x06E0, 0x06E1, 0x06E2, 0x06E3, 0x06E4, 0x06E7, 0x06E8, 0x06EA, 0x06EB, 0x06EC, 0x06ED, 0x0711, - 0x0730, 0x0731, 0x0732, 0x0733, 0x0734, 0x0735, 0x0736, 0x0737, 0x0738, 0x0739, 0x073A, 0x073B, 0x073C, - 0x073D, 0x073E, 0x073F, 0x0740, 0x0741, 0x0742, 0x0743, 0x0744, 0x0745, 0x0746, 0x0747, 0x0748, 0x0749, - 0x074A, 0x07A6, 0x07A7, 0x07A8, 0x07A9, 0x07AA, 0x07AB, 0x07AC, 0x07AD, 0x07AE, 0x07AF, 0x07B0, 0x07EB, - 0x07EC, 0x07ED, 0x07EE, 0x07EF, 0x07F0, 0x07F1, 0x07F2, 0x07F3, 0x07FD, 0x0816, 0x0817, 0x0818, 0x0819, - 0x081B, 0x081C, 0x081D, 0x081E, 0x081F, 0x0820, 0x0821, 0x0822, 0x0823, 0x0825, 0x0826, 0x0827, 0x0829, - 0x082A, 0x082B, 0x082C, 0x082D, 0x0859, 0x085A, 0x085B, 0x0897, 0x0898, 0x0899, 0x089A, 0x089B, 0x089C, - 0x089D, 0x089E, 0x089F, 0x08CA, 0x08CB, 0x08CC, 0x08CD, 0x08CE, 0x08CF, 0x08D0, 0x08D1, 0x08D2, 0x08D3, - 0x08D4, 0x08D5, 0x08D6, 0x08D7, 0x08D8, 0x08D9, 0x08DA, 0x08DB, 0x08DC, 0x08DD, 0x08DE, 0x08DF, 0x08E0, - 0x08E1, 0x08E3, 0x08E4, 0x08E5, 0x08E6, 0x08E7, 0x08E8, 0x08E9, 0x08EA, 0x08EB, 0x08EC, 0x08ED, 0x08EE, - 0x08EF, 0x08F0, 0x08F1, 0x08F2, 0x08F3, 0x08F4, 0x08F5, 0x08F6, 0x08F7, 0x08F8, 0x08F9, 0x08FA, 0x08FB, - 0x08FC, 0x08FD, 0x08FE, 0x08FF, 0x0900, 0x0901, 0x0902, 0x093A, 0x093C, 0x0941, 0x0942, 0x0943, 0x0944, - 0x0945, 0x0946, 0x0947, 0x0948, 0x094D, 0x0951, 0x0952, 0x0953, 0x0954, 0x0955, 0x0956, 0x0957, 0x0962, - 0x0963, 0x0981, 0x09BC, 0x09C1, 0x09C2, 0x09C3, 0x09C4, 0x09CD, 0x09E2, 0x09E3, 0x09FE, 0x0A01, 0x0A02, - 0x0A3C, 0x0A41, 0x0A42, 0x0A47, 0x0A48, 0x0A4B, 0x0A4C, 0x0A4D, 0x0A51, 0x0A70, 0x0A71, 0x0A75, 0x0A81, - 0x0A82, 0x0ABC, 0x0AC1, 0x0AC2, 0x0AC3, 0x0AC4, 0x0AC5, 0x0AC7, 0x0AC8, 0x0ACD, 0x0AE2, 0x0AE3, 0x0AFA, - 0x0AFB, 0x0AFC, 0x0AFD, 0x0AFE, 0x0AFF, 0x0B01, 0x0B3C, 0x0B3F, 0x0B41, 0x0B42, 0x0B43, 0x0B44, 0x0B4D, - 0x0B55, 0x0B56, 0x0B62, 0x0B63, 0x0B82, 0x0BC0, 0x0BCD, 0x0C00, 0x0C04, 0x0C3C, 0x0C3E, 0x0C3F, 0x0C40, - 0x0C46, 0x0C47, 0x0C48, 0x0C4A, 0x0C4B, 0x0C4C, 0x0C4D, 0x0C55, 0x0C56, 0x0C62, 0x0C63, 0x0C81, 0x0CBC, - 0x0CBF, 0x0CC6, 0x0CCC, 0x0CCD, 0x0CE2, 0x0CE3, 0x0D00, 0x0D01, 0x0D3B, 0x0D3C, 0x0D41, 0x0D42, 0x0D43, - 0x0D44, 0x0D4D, 0x0D62, 0x0D63, 0x0D81, 0x0DCA, 0x0DD2, 0x0DD3, 0x0DD4, 0x0DD6, 0x0E31, 0x0E34, 0x0E35, - 0x0E36, 0x0E37, 0x0E38, 0x0E39, 0x0E3A, 0x0E47, 0x0E48, 0x0E49, 0x0E4A, 0x0E4B, 0x0E4C, 0x0E4D, 0x0E4E, - 0x0EB1, 0x0EB4, 0x0EB5, 0x0EB6, 0x0EB7, 0x0EB8, 0x0EB9, 0x0EBA, 0x0EBB, 0x0EBC, 0x0EC8, 0x0EC9, 0x0ECA, - 0x0ECB, 0x0ECC, 0x0ECD, 0x0ECE, 0x0F18, 0x0F19, 0x0F35, 0x0F37, 0x0F39, 0x0F71, 0x0F72, 0x0F73, 0x0F74, - 0x0F75, 0x0F76, 0x0F77, 0x0F78, 0x0F79, 0x0F7A, 0x0F7B, 0x0F7C, 0x0F7D, 0x0F7E, 0x0F80, 0x0F81, 0x0F82, - 0x0F83, 0x0F84, 0x0F86, 0x0F87, 0x0F8D, 0x0F8E, 0x0F8F, 0x0F90, 0x0F91, 0x0F92, 0x0F93, 0x0F94, 0x0F95, - 0x0F96, 0x0F97, 0x0F99, 0x0F9A, 0x0F9B, 0x0F9C, 0x0F9D, 0x0F9E, 0x0F9F, 0x0FA0, 0x0FA1, 0x0FA2, 0x0FA3, - 0x0FA4, 0x0FA5, 0x0FA6, 0x0FA7, 0x0FA8, 0x0FA9, 0x0FAA, 0x0FAB, 0x0FAC, 0x0FAD, 0x0FAE, 0x0FAF, 0x0FB0, - 0x0FB1, 0x0FB2, 0x0FB3, 0x0FB4, 0x0FB5, 0x0FB6, 0x0FB7, 0x0FB8, 0x0FB9, 0x0FBA, 0x0FBB, 0x0FBC, 0x0FC6, - 0x102D, 0x102E, 0x102F, 0x1030, 0x1032, 0x1033, 0x1034, 0x1035, 0x1036, 0x1037, 0x1039, 0x103A, 0x103D, - 0x103E, 0x1058, 0x1059, 0x105E, 0x105F, 0x1060, 0x1071, 0x1072, 0x1073, 0x1074, 0x1082, 0x1085, 0x1086, - 0x108D, 0x109D, 0x135D, 0x135E, 0x135F, 0x1712, 0x1713, 0x1714, 0x1732, 0x1733, 0x1752, 0x1753, 0x1772, - 0x1773, 0x17B4, 0x17B5, 0x17B7, 0x17B8, 0x17B9, 0x17BA, 0x17BB, 0x17BC, 0x17BD, 0x17C6, 0x17C9, 0x17CA, - 0x17CB, 0x17CC, 0x17CD, 0x17CE, 0x17CF, 0x17D0, 0x17D1, 0x17D2, 0x17D3, 0x17DD, 0x180B, 0x180C, 0x180D, - 0x180F, 0x1885, 0x1886, 0x18A9, 0x1920, 0x1921, 0x1922, 0x1927, 0x1928, 0x1932, 0x1939, 0x193A, 0x193B, - 0x1A17, 0x1A18, 0x1A1B, 0x1A56, 0x1A58, 0x1A59, 0x1A5A, 0x1A5B, 0x1A5C, 0x1A5D, 0x1A5E, 0x1A60, 0x1A62, - 0x1A65, 0x1A66, 0x1A67, 0x1A68, 0x1A69, 0x1A6A, 0x1A6B, 0x1A6C, 0x1A73, 0x1A74, 0x1A75, 0x1A76, 0x1A77, - 0x1A78, 0x1A79, 0x1A7A, 0x1A7B, 0x1A7C, 0x1A7F, 0x1AB0, 0x1AB1, 0x1AB2, 0x1AB3, 0x1AB4, 0x1AB5, 0x1AB6, - 0x1AB7, 0x1AB8, 0x1AB9, 0x1ABA, 0x1ABB, 0x1ABC, 0x1ABD, 0x1ABF, 0x1AC0, 0x1AC1, 0x1AC2, 0x1AC3, 0x1AC4, - 0x1AC5, 0x1AC6, 0x1AC7, 0x1AC8, 0x1AC9, 0x1ACA, 0x1ACB, 0x1ACC, 0x1ACD, 0x1ACE, 0x1B00, 0x1B01, 0x1B02, - 0x1B03, 0x1B34, 0x1B36, 0x1B37, 0x1B38, 0x1B39, 0x1B3A, 0x1B3C, 0x1B42, 0x1B6B, 0x1B6C, 0x1B6D, 0x1B6E, - 0x1B6F, 0x1B70, 0x1B71, 0x1B72, 0x1B73, 0x1B80, 0x1B81, 0x1BA2, 0x1BA3, 0x1BA4, 0x1BA5, 0x1BA8, 0x1BA9, - 0x1BAB, 0x1BAC, 0x1BAD, 0x1BE6, 0x1BE8, 0x1BE9, 0x1BED, 0x1BEF, 0x1BF0, 0x1BF1, 0x1C2C, 0x1C2D, 0x1C2E, - 0x1C2F, 0x1C30, 0x1C31, 0x1C32, 0x1C33, 0x1C36, 0x1C37, 0x1CD0, 0x1CD1, 0x1CD2, 0x1CD4, 0x1CD5, 0x1CD6, - 0x1CD7, 0x1CD8, 0x1CD9, 0x1CDA, 0x1CDB, 0x1CDC, 0x1CDD, 0x1CDE, 0x1CDF, 0x1CE0, 0x1CE2, 0x1CE3, 0x1CE4, - 0x1CE5, 0x1CE6, 0x1CE7, 0x1CE8, 0x1CED, 0x1CF4, 0x1CF8, 0x1CF9, 0x1DC0, 0x1DC1, 0x1DC2, 0x1DC3, 0x1DC4, - 0x1DC5, 0x1DC6, 0x1DC7, 0x1DC8, 0x1DC9, 0x1DCA, 0x1DCB, 0x1DCC, 0x1DCD, 0x1DCE, 0x1DCF, 0x1DD0, 0x1DD1, - 0x1DD2, 0x1DD3, 0x1DD4, 0x1DD5, 0x1DD6, 0x1DD7, 0x1DD8, 0x1DD9, 0x1DDA, 0x1DDB, 0x1DDC, 0x1DDD, 0x1DDE, - 0x1DDF, 0x1DE0, 0x1DE1, 0x1DE2, 0x1DE3, 0x1DE4, 0x1DE5, 0x1DE6, 0x1DE7, 0x1DE8, 0x1DE9, 0x1DEA, 0x1DEB, - 0x1DEC, 0x1DED, 0x1DEE, 0x1DEF, 0x1DF0, 0x1DF1, 0x1DF2, 0x1DF3, 0x1DF4, 0x1DF5, 0x1DF6, 0x1DF7, 0x1DF8, - 0x1DF9, 0x1DFA, 0x1DFB, 0x1DFC, 0x1DFD, 0x1DFE, 0x1DFF, 0x20D0, 0x20D1, 0x20D2, 0x20D3, 0x20D4, 0x20D5, - 0x20D6, 0x20D7, 0x20D8, 0x20D9, 0x20DA, 0x20DB, 0x20DC, 0x20E1, 0x20E5, 0x20E6, 0x20E7, 0x20E8, 0x20E9, - 0x20EA, 0x20EB, 0x20EC, 0x20ED, 0x20EE, 0x20EF, 0x20F0, 0x2CEF, 0x2CF0, 0x2CF1, 0x2D7F, 0x2DE0, 0x2DE1, - 0x2DE2, 0x2DE3, 0x2DE4, 0x2DE5, 0x2DE6, 0x2DE7, 0x2DE8, 0x2DE9, 0x2DEA, 0x2DEB, 0x2DEC, 0x2DED, 0x2DEE, - 0x2DEF, 0x2DF0, 0x2DF1, 0x2DF2, 0x2DF3, 0x2DF4, 0x2DF5, 0x2DF6, 0x2DF7, 0x2DF8, 0x2DF9, 0x2DFA, 0x2DFB, - 0x2DFC, 0x2DFD, 0x2DFE, 0x2DFF, 0x302A, 0x302B, 0x302C, 0x302D, 0x3099, 0x309A, 0xA66F, 0xA674, 0xA675, - 0xA676, 0xA677, 0xA678, 0xA679, 0xA67A, 0xA67B, 0xA67C, 0xA67D, 0xA69E, 0xA69F, 0xA6F0, 0xA6F1, 0xA802, - 0xA806, 0xA80B, 0xA825, 0xA826, 0xA82C, 0xA8C4, 0xA8C5, 0xA8E0, 0xA8E1, 0xA8E2, 0xA8E3, 0xA8E4, 0xA8E5, - 0xA8E6, 0xA8E7, 0xA8E8, 0xA8E9, 0xA8EA, 0xA8EB, 0xA8EC, 0xA8ED, 0xA8EE, 0xA8EF, 0xA8F0, 0xA8F1, 0xA8FF, - 0xA926, 0xA927, 0xA928, 0xA929, 0xA92A, 0xA92B, 0xA92C, 0xA92D, 0xA947, 0xA948, 0xA949, 0xA94A, 0xA94B, - 0xA94C, 0xA94D, 0xA94E, 0xA94F, 0xA950, 0xA951, 0xA980, 0xA981, 0xA982, 0xA9B3, 0xA9B6, 0xA9B7, 0xA9B8, - 0xA9B9, 0xA9BC, 0xA9BD, 0xA9E5, 0xAA29, 0xAA2A, 0xAA2B, 0xAA2C, 0xAA2D, 0xAA2E, 0xAA31, 0xAA32, 0xAA35, - 0xAA36, 0xAA43, 0xAA4C, 0xAA7C, 0xAAB0, 0xAAB2, 0xAAB3, 0xAAB4, 0xAAB7, 0xAAB8, 0xAABE, 0xAABF, 0xAAC1, - 0xAAEC, 0xAAED, 0xAAF6, 0xABE5, 0xABE8, 0xABED, 0xFB1E, 0xFE00, 0xFE01, 0xFE02, 0xFE03, 0xFE04, 0xFE05, - 0xFE06, 0xFE07, 0xFE08, 0xFE09, 0xFE0A, 0xFE0B, 0xFE0C, 0xFE0D, 0xFE0E, 0xFE0F, 0xFE20, 0xFE21, 0xFE22, - 0xFE23, 0xFE24, 0xFE25, 0xFE26, 0xFE27, 0xFE28, 0xFE29, 0xFE2A, 0xFE2B, 0xFE2C, 0xFE2D, 0xFE2E, 0xFE2F, - 0x101FD, 0x102E0, 0x10376, 0x10377, 0x10378, 0x10379, 0x1037A, 0x10A01, 0x10A02, 0x10A03, 0x10A05, 0x10A06, 0x10A0C, - 0x10A0D, 0x10A0E, 0x10A0F, 0x10A38, 0x10A39, 0x10A3A, 0x10A3F, 0x10AE5, 0x10AE6, 0x10D24, 0x10D25, 0x10D26, 0x10D27, - 0x10D69, 0x10D6A, 0x10D6B, 0x10D6C, 0x10D6D, 0x10EAB, 0x10EAC, 0x10EFC, 0x10EFD, 0x10EFE, 0x10EFF, 0x10F46, 0x10F47, - 0x10F48, 0x10F49, 0x10F4A, 0x10F4B, 0x10F4C, 0x10F4D, 0x10F4E, 0x10F4F, 0x10F50, 0x10F82, 0x10F83, 0x10F84, 0x10F85, - 0x11001, 0x11038, 0x11039, 0x1103A, 0x1103B, 0x1103C, 0x1103D, 0x1103E, 0x1103F, 0x11040, 0x11041, 0x11042, 0x11043, - 0x11044, 0x11045, 0x11046, 0x11070, 0x11073, 0x11074, 0x1107F, 0x11080, 0x11081, 0x110B3, 0x110B4, 0x110B5, 0x110B6, - 0x110B9, 0x110BA, 0x110C2, 0x11100, 0x11101, 0x11102, 0x11127, 0x11128, 0x11129, 0x1112A, 0x1112B, 0x1112D, 0x1112E, - 0x1112F, 0x11130, 0x11131, 0x11132, 0x11133, 0x11134, 0x11173, 0x11180, 0x11181, 0x111B6, 0x111B7, 0x111B8, 0x111B9, - 0x111BA, 0x111BB, 0x111BC, 0x111BD, 0x111BE, 0x111C9, 0x111CA, 0x111CB, 0x111CC, 0x111CF, 0x1122F, 0x11230, 0x11231, - 0x11234, 0x11236, 0x11237, 0x1123E, 0x11241, 0x112DF, 0x112E3, 0x112E4, 0x112E5, 0x112E6, 0x112E7, 0x112E8, 0x112E9, - 0x112EA, 0x11300, 0x11301, 0x1133B, 0x1133C, 0x11340, 0x11366, 0x11367, 0x11368, 0x11369, 0x1136A, 0x1136B, 0x1136C, - 0x11370, 0x11371, 0x11372, 0x11373, 0x11374, 0x113BB, 0x113BC, 0x113BD, 0x113BE, 0x113BF, 0x113C0, 0x113CE, 0x113D0, - 0x113D2, 0x113E1, 0x113E2, 0x11438, 0x11439, 0x1143A, 0x1143B, 0x1143C, 0x1143D, 0x1143E, 0x1143F, 0x11442, 0x11443, - 0x11444, 0x11446, 0x1145E, 0x114B3, 0x114B4, 0x114B5, 0x114B6, 0x114B7, 0x114B8, 0x114BA, 0x114BF, 0x114C0, 0x114C2, - 0x114C3, 0x115B2, 0x115B3, 0x115B4, 0x115B5, 0x115BC, 0x115BD, 0x115BF, 0x115C0, 0x115DC, 0x115DD, 0x11633, 0x11634, - 0x11635, 0x11636, 0x11637, 0x11638, 0x11639, 0x1163A, 0x1163D, 0x1163F, 0x11640, 0x116AB, 0x116AD, 0x116B0, 0x116B1, - 0x116B2, 0x116B3, 0x116B4, 0x116B5, 0x116B7, 0x1171D, 0x1171F, 0x11722, 0x11723, 0x11724, 0x11725, 0x11727, 0x11728, - 0x11729, 0x1172A, 0x1172B, 0x1182F, 0x11830, 0x11831, 0x11832, 0x11833, 0x11834, 0x11835, 0x11836, 0x11837, 0x11839, - 0x1183A, 0x1193B, 0x1193C, 0x1193E, 0x11943, 0x119D4, 0x119D5, 0x119D6, 0x119D7, 0x119DA, 0x119DB, 0x119E0, 0x11A01, - 0x11A02, 0x11A03, 0x11A04, 0x11A05, 0x11A06, 0x11A07, 0x11A08, 0x11A09, 0x11A0A, 0x11A33, 0x11A34, 0x11A35, 0x11A36, - 0x11A37, 0x11A38, 0x11A3B, 0x11A3C, 0x11A3D, 0x11A3E, 0x11A47, 0x11A51, 0x11A52, 0x11A53, 0x11A54, 0x11A55, 0x11A56, - 0x11A59, 0x11A5A, 0x11A5B, 0x11A8A, 0x11A8B, 0x11A8C, 0x11A8D, 0x11A8E, 0x11A8F, 0x11A90, 0x11A91, 0x11A92, 0x11A93, - 0x11A94, 0x11A95, 0x11A96, 0x11A98, 0x11A99, 0x11C30, 0x11C31, 0x11C32, 0x11C33, 0x11C34, 0x11C35, 0x11C36, 0x11C38, - 0x11C39, 0x11C3A, 0x11C3B, 0x11C3C, 0x11C3D, 0x11C3F, 0x11C92, 0x11C93, 0x11C94, 0x11C95, 0x11C96, 0x11C97, 0x11C98, - 0x11C99, 0x11C9A, 0x11C9B, 0x11C9C, 0x11C9D, 0x11C9E, 0x11C9F, 0x11CA0, 0x11CA1, 0x11CA2, 0x11CA3, 0x11CA4, 0x11CA5, - 0x11CA6, 0x11CA7, 0x11CAA, 0x11CAB, 0x11CAC, 0x11CAD, 0x11CAE, 0x11CAF, 0x11CB0, 0x11CB2, 0x11CB3, 0x11CB5, 0x11CB6, - 0x11D31, 0x11D32, 0x11D33, 0x11D34, 0x11D35, 0x11D36, 0x11D3A, 0x11D3C, 0x11D3D, 0x11D3F, 0x11D40, 0x11D41, 0x11D42, - 0x11D43, 0x11D44, 0x11D45, 0x11D47, 0x11D90, 0x11D91, 0x11D95, 0x11D97, 0x11EF3, 0x11EF4, 0x11F00, 0x11F01, 0x11F36, - 0x11F37, 0x11F38, 0x11F39, 0x11F3A, 0x11F40, 0x11F42, 0x11F5A, 0x13440, 0x13447, 0x13448, 0x13449, 0x1344A, 0x1344B, - 0x1344C, 0x1344D, 0x1344E, 0x1344F, 0x13450, 0x13451, 0x13452, 0x13453, 0x13454, 0x13455, 0x1611E, 0x1611F, 0x16120, - 0x16121, 0x16122, 0x16123, 0x16124, 0x16125, 0x16126, 0x16127, 0x16128, 0x16129, 0x1612D, 0x1612E, 0x1612F, 0x16AF0, - 0x16AF1, 0x16AF2, 0x16AF3, 0x16AF4, 0x16B30, 0x16B31, 0x16B32, 0x16B33, 0x16B34, 0x16B35, 0x16B36, 0x16F4F, 0x16F8F, - 0x16F90, 0x16F91, 0x16F92, 0x16FE4, 0x1BC9D, 0x1BC9E, 0x1CF00, 0x1CF01, 0x1CF02, 0x1CF03, 0x1CF04, 0x1CF05, 0x1CF06, - 0x1CF07, 0x1CF08, 0x1CF09, 0x1CF0A, 0x1CF0B, 0x1CF0C, 0x1CF0D, 0x1CF0E, 0x1CF0F, 0x1CF10, 0x1CF11, 0x1CF12, 0x1CF13, - 0x1CF14, 0x1CF15, 0x1CF16, 0x1CF17, 0x1CF18, 0x1CF19, 0x1CF1A, 0x1CF1B, 0x1CF1C, 0x1CF1D, 0x1CF1E, 0x1CF1F, 0x1CF20, - 0x1CF21, 0x1CF22, 0x1CF23, 0x1CF24, 0x1CF25, 0x1CF26, 0x1CF27, 0x1CF28, 0x1CF29, 0x1CF2A, 0x1CF2B, 0x1CF2C, 0x1CF2D, - 0x1CF30, 0x1CF31, 0x1CF32, 0x1CF33, 0x1CF34, 0x1CF35, 0x1CF36, 0x1CF37, 0x1CF38, 0x1CF39, 0x1CF3A, 0x1CF3B, 0x1CF3C, - 0x1CF3D, 0x1CF3E, 0x1CF3F, 0x1CF40, 0x1CF41, 0x1CF42, 0x1CF43, 0x1CF44, 0x1CF45, 0x1CF46, 0x1D167, 0x1D168, 0x1D169, - 0x1D17B, 0x1D17C, 0x1D17D, 0x1D17E, 0x1D17F, 0x1D180, 0x1D181, 0x1D182, 0x1D185, 0x1D186, 0x1D187, 0x1D188, 0x1D189, - 0x1D18A, 0x1D18B, 0x1D1AA, 0x1D1AB, 0x1D1AC, 0x1D1AD, 0x1D242, 0x1D243, 0x1D244, 0x1DA00, 0x1DA01, 0x1DA02, 0x1DA03, - 0x1DA04, 0x1DA05, 0x1DA06, 0x1DA07, 0x1DA08, 0x1DA09, 0x1DA0A, 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, - 0x1DA11, 0x1DA12, 0x1DA13, 0x1DA14, 0x1DA15, 0x1DA16, 0x1DA17, 0x1DA18, 0x1DA19, 0x1DA1A, 0x1DA1B, 0x1DA1C, 0x1DA1D, - 0x1DA1E, 0x1DA1F, 0x1DA20, 0x1DA21, 0x1DA22, 0x1DA23, 0x1DA24, 0x1DA25, 0x1DA26, 0x1DA27, 0x1DA28, 0x1DA29, 0x1DA2A, - 0x1DA2B, 0x1DA2C, 0x1DA2D, 0x1DA2E, 0x1DA2F, 0x1DA30, 0x1DA31, 0x1DA32, 0x1DA33, 0x1DA34, 0x1DA35, 0x1DA36, 0x1DA3B, - 0x1DA3C, 0x1DA3D, 0x1DA3E, 0x1DA3F, 0x1DA40, 0x1DA41, 0x1DA42, 0x1DA43, 0x1DA44, 0x1DA45, 0x1DA46, 0x1DA47, 0x1DA48, - 0x1DA49, 0x1DA4A, 0x1DA4B, 0x1DA4C, 0x1DA4D, 0x1DA4E, 0x1DA4F, 0x1DA50, 0x1DA51, 0x1DA52, 0x1DA53, 0x1DA54, 0x1DA55, - 0x1DA56, 0x1DA57, 0x1DA58, 0x1DA59, 0x1DA5A, 0x1DA5B, 0x1DA5C, 0x1DA5D, 0x1DA5E, 0x1DA5F, 0x1DA60, 0x1DA61, 0x1DA62, - 0x1DA63, 0x1DA64, 0x1DA65, 0x1DA66, 0x1DA67, 0x1DA68, 0x1DA69, 0x1DA6A, 0x1DA6B, 0x1DA6C, 0x1DA75, 0x1DA84, 0x1DA9B, - 0x1DA9C, 0x1DA9D, 0x1DA9E, 0x1DA9F, 0x1DAA1, 0x1DAA2, 0x1DAA3, 0x1DAA4, 0x1DAA5, 0x1DAA6, 0x1DAA7, 0x1DAA8, 0x1DAA9, - 0x1DAAA, 0x1DAAB, 0x1DAAC, 0x1DAAD, 0x1DAAE, 0x1DAAF, 0x1E000, 0x1E001, 0x1E002, 0x1E003, 0x1E004, 0x1E005, 0x1E006, - 0x1E008, 0x1E009, 0x1E00A, 0x1E00B, 0x1E00C, 0x1E00D, 0x1E00E, 0x1E00F, 0x1E010, 0x1E011, 0x1E012, 0x1E013, 0x1E014, - 0x1E015, 0x1E016, 0x1E017, 0x1E018, 0x1E01B, 0x1E01C, 0x1E01D, 0x1E01E, 0x1E01F, 0x1E020, 0x1E021, 0x1E023, 0x1E024, - 0x1E026, 0x1E027, 0x1E028, 0x1E029, 0x1E02A, 0x1E08F, 0x1E130, 0x1E131, 0x1E132, 0x1E133, 0x1E134, 0x1E135, 0x1E136, - 0x1E2AE, 0x1E2EC, 0x1E2ED, 0x1E2EE, 0x1E2EF, 0x1E4EC, 0x1E4ED, 0x1E4EE, 0x1E4EF, 0x1E5EE, 0x1E5EF, 0x1E8D0, 0x1E8D1, - 0x1E8D2, 0x1E8D3, 0x1E8D4, 0x1E8D5, 0x1E8D6, 0x1E944, 0x1E945, 0x1E946, 0x1E947, 0x1E948, 0x1E949, 0x1E94A, 0xE0100, - 0xE0101, 0xE0102, 0xE0103, 0xE0104, 0xE0105, 0xE0106, 0xE0107, 0xE0108, 0xE0109, 0xE010A, 0xE010B, 0xE010C, 0xE010D, - 0xE010E, 0xE010F, 0xE0110, 0xE0111, 0xE0112, 0xE0113, 0xE0114, 0xE0115, 0xE0116, 0xE0117, 0xE0118, 0xE0119, 0xE011A, - 0xE011B, 0xE011C, 0xE011D, 0xE011E, 0xE011F, 0xE0120, 0xE0121, 0xE0122, 0xE0123, 0xE0124, 0xE0125, 0xE0126, 0xE0127, - 0xE0128, 0xE0129, 0xE012A, 0xE012B, 0xE012C, 0xE012D, 0xE012E, 0xE012F, 0xE0130, 0xE0131, 0xE0132, 0xE0133, 0xE0134, - 0xE0135, 0xE0136, 0xE0137, 0xE0138, 0xE0139, 0xE013A, 0xE013B, 0xE013C, 0xE013D, 0xE013E, 0xE013F, 0xE0140, 0xE0141, - 0xE0142, 0xE0143, 0xE0144, 0xE0145, 0xE0146, 0xE0147, 0xE0148, 0xE0149, 0xE014A, 0xE014B, 0xE014C, 0xE014D, 0xE014E, - 0xE014F, 0xE0150, 0xE0151, 0xE0152, 0xE0153, 0xE0154, 0xE0155, 0xE0156, 0xE0157, 0xE0158, 0xE0159, 0xE015A, 0xE015B, - 0xE015C, 0xE015D, 0xE015E, 0xE015F, 0xE0160, 0xE0161, 0xE0162, 0xE0163, 0xE0164, 0xE0165, 0xE0166, 0xE0167, 0xE0168, - 0xE0169, 0xE016A, 0xE016B, 0xE016C, 0xE016D, 0xE016E, 0xE016F, 0xE0170, 0xE0171, 0xE0172, 0xE0173, 0xE0174, 0xE0175, - 0xE0176, 0xE0177, 0xE0178, 0xE0179, 0xE017A, 0xE017B, 0xE017C, 0xE017D, 0xE017E, 0xE017F, 0xE0180, 0xE0181, 0xE0182, - 0xE0183, 0xE0184, 0xE0185, 0xE0186, 0xE0187, 0xE0188, 0xE0189, 0xE018A, 0xE018B, 0xE018C, 0xE018D, 0xE018E, 0xE018F, - 0xE0190, 0xE0191, 0xE0192, 0xE0193, 0xE0194, 0xE0195, 0xE0196, 0xE0197, 0xE0198, 0xE0199, 0xE019A, 0xE019B, 0xE019C, - 0xE019D, 0xE019E, 0xE019F, 0xE01A0, 0xE01A1, 0xE01A2, 0xE01A3, 0xE01A4, 0xE01A5, 0xE01A6, 0xE01A7, 0xE01A8, 0xE01A9, - 0xE01AA, 0xE01AB, 0xE01AC, 0xE01AD, 0xE01AE, 0xE01AF, 0xE01B0, 0xE01B1, 0xE01B2, 0xE01B3, 0xE01B4, 0xE01B5, 0xE01B6, - 0xE01B7, 0xE01B8, 0xE01B9, 0xE01BA, 0xE01BB, 0xE01BC, 0xE01BD, 0xE01BE, 0xE01BF, 0xE01C0, 0xE01C1, 0xE01C2, 0xE01C3, - 0xE01C4, 0xE01C5, 0xE01C6, 0xE01C7, 0xE01C8, 0xE01C9, 0xE01CA, 0xE01CB, 0xE01CC, 0xE01CD, 0xE01CE, 0xE01CF, 0xE01D0, - 0xE01D1, 0xE01D2, 0xE01D3, 0xE01D4, 0xE01D5, 0xE01D6, 0xE01D7, 0xE01D8, 0xE01D9, 0xE01DA, 0xE01DB, 0xE01DC, 0xE01DD, - 0xE01DE, 0xE01DF, 0xE01E0, 0xE01E1, 0xE01E2, 0xE01E3, 0xE01E4, 0xE01E5, 0xE01E6, 0xE01E7, 0xE01E8, 0xE01E9, 0xE01EA, - 0xE01EB, 0xE01EC, 0xE01ED, 0xE01EE, 0xE01EF - /* END: COMBINING CHAR TABLE */ -}; - -static const unsigned long combiningCharTableSize = sizeof(combiningCharTable) / sizeof(combiningCharTable[0]); - -static bool isCombiningChar(unsigned long cp) { - for (size_t i = 0; i < combiningCharTableSize; i++) { - auto code = combiningCharTable[i]; - if (code > cp) { - return false; - } - if (code == cp) { - return true; - } - } - return false; -} - -/* Get length of previous grapheme */ -static size_t defaultPrevCharLen(const char * buf, size_t /*buf_len*/, size_t pos, size_t * col_len) { - size_t end = pos; - while (pos > 0) { - size_t len = prevUtf8CodePointLen(buf, pos); - pos -= len; - int cp; - utf8BytesToCodePoint(buf + pos, len, &cp); - if (!isCombiningChar(cp)) { - if (col_len != NULL) { - *col_len = isWideChar(cp) ? 2 : 1; - } - return end - pos; - } - } - /* NOTREACHED */ - return 0; -} - -/* Get length of next grapheme */ -static size_t defaultNextCharLen(const char * buf, size_t buf_len, size_t pos, size_t * col_len) { - size_t beg = pos; - int cp; - size_t len = utf8BytesToCodePoint(buf + pos, buf_len - pos, &cp); - if (isCombiningChar(cp)) { - /* NOTREACHED */ - return 0; - } - if (col_len != NULL) { - *col_len = isWideChar(cp) ? 2 : 1; - } - pos += len; - while (pos < buf_len) { - int cp; - len = utf8BytesToCodePoint(buf + pos, buf_len - pos, &cp); - if (!isCombiningChar(cp)) { - return pos - beg; - } - pos += len; - } - return pos - beg; -} - -/* Read a Unicode from file. */ -static size_t defaultReadCode(int fd, char * buf, size_t buf_len, int * cp) { - if (buf_len < 1) { - return -1; - } - size_t nread = read(fd, &buf[0], 1); - if (nread <= 0) { - return nread; - } - - unsigned char byte = buf[0]; - if ((byte & 0x80) == 0) { - ; - } else if ((byte & 0xE0) == 0xC0) { - if (buf_len < 2) { - return -1; - } - nread = read(fd, &buf[1], 1); - if (nread <= 0) { - return nread; - } - } else if ((byte & 0xF0) == 0xE0) { - if (buf_len < 3) { - return -1; - } - nread = read(fd, &buf[1], 2); - if (nread <= 0) { - return nread; - } - } else if ((byte & 0xF8) == 0xF0) { - if (buf_len < 3) { - return -1; - } - nread = read(fd, &buf[1], 3); - if (nread <= 0) { - return nread; - } - } else { - return -1; - } - - return utf8BytesToCodePoint(buf, buf_len, cp); -} - -/* Set default encoding functions */ -static linenoisePrevCharLen * prevCharLen = defaultPrevCharLen; -static linenoiseNextCharLen * nextCharLen = defaultNextCharLen; -static linenoiseReadCode * readCode = defaultReadCode; - -/* Set used defined encoding functions */ -void linenoiseSetEncodingFunctions(linenoisePrevCharLen * prevCharLenFunc, linenoiseNextCharLen * nextCharLenFunc, - linenoiseReadCode * readCodeFunc) { - prevCharLen = prevCharLenFunc; - nextCharLen = nextCharLenFunc; - readCode = readCodeFunc; -} - -/* ======================= Low level terminal handling ====================== */ - -/* Enable "mask mode". When it is enabled, instead of the input that - * the user is typing, the terminal will just display a corresponding - * number of asterisks, like "****". This is useful for passwords and other - * secrets that should not be displayed. */ -void linenoiseMaskModeEnable(void) { - maskmode = 1; -} - -/* Disable mask mode. */ -void linenoiseMaskModeDisable(void) { - maskmode = 0; -} - -/* Set if to use or not the multi line mode. */ -void linenoiseSetMultiLine(int ml) { - mlmode = ml; -} - -/* Return true if the terminal name is in the list of terminals we know are - * not able to understand basic escape sequences. */ -static int isUnsupportedTerm(void) { - char *term = getenv("TERM"); - if (term == NULL) return 0; - for (size_t j = 0; j < unsupported_term.size(); ++j) { - if (!strcasecmp(term, unsupported_term[j])) { - return 1; - } - } - return 0; -} - -/* Raw mode: 1960 magic shit. */ -static int enableRawMode(int fd) { - struct termios raw; - - if (!isatty(STDIN_FILENO)) goto fatal; - if (!atexit_registered) { - atexit(linenoiseAtExit); - atexit_registered = 1; - } - if (tcgetattr(fd,&orig_termios) == -1) goto fatal; - - raw = orig_termios; /* modify the original mode */ - /* input modes: no break, no CR to NL, no parity check, no strip char, - * no start/stop output control. */ - raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); - /* output modes - disable post processing */ - raw.c_oflag &= ~(OPOST); - /* control modes - set 8 bit chars */ - raw.c_cflag |= (CS8); - /* local modes - choing off, canonical off, no extended functions, - * no signal chars (^Z,^C) */ - raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); - /* control chars - set return condition: min number of bytes and timer. - * We want read to return every single byte, without timeout. */ - raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ - - /* put terminal in raw mode after flushing */ - if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; - rawmode = 1; - return 0; - -fatal: - errno = ENOTTY; - return -1; -} - -static void disableRawMode(int fd) { - /* Don't even check the return value as it's too late. */ - if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) - rawmode = 0; -} - -/* Use the ESC [6n escape sequence to query the horizontal cursor position - * and return it. On error -1 is returned, on success the position of the - * cursor. */ -static int getCursorPosition(int ifd, int ofd) { - char buf[32]; - int cols, rows; - unsigned int i = 0; - - /* Report cursor location */ - if (write(ofd, "\x1b[6n", 4) != 4) return -1; - - /* Read the response: ESC [ rows ; cols R */ - while (i < sizeof(buf)-1) { - if (read(ifd,buf+i,1) != 1) break; - if (buf[i] == 'R') break; - i++; - } - buf[i] = '\0'; - - /* Parse it. */ - if (buf[0] != ESC || buf[1] != '[') return -1; - if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; - return cols; -} - -/* Try to get the number of columns in the current terminal, or assume 80 - * if it fails. */ -static int getColumns(int ifd, int ofd) { - struct winsize ws; - - if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { - /* ioctl() failed. Try to query the terminal itself. */ - int start, cols; - - /* Get the initial position so we can restore it later. */ - start = getCursorPosition(ifd,ofd); - if (start == -1) goto failed; - - /* Go to right margin and get position. */ - if (write(ofd,"\x1b[999C",6) != 6) goto failed; - cols = getCursorPosition(ifd,ofd); - if (cols == -1) goto failed; - - /* Restore position. */ - if (cols > start) { - char seq[32]; - snprintf(seq,32,"\x1b[%dD",cols-start); - if (write(ofd,seq,strlen(seq)) == -1) { - /* Can't recover... */ - } - } - return cols; - } else { - return ws.ws_col; - } - -failed: - return 80; -} - -/* Clear the screen. Used to handle ctrl+l */ -void linenoiseClearScreen(void) { - if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { - /* nothing to do, just to avoid warning. */ - } -} - -/* Beep, used for completion when there is nothing to complete or when all - * the choices were already shown. */ -static void linenoiseBeep(void) { - fprintf(stderr, "\x7"); - fflush(stderr); -} - -/* Called by completeLine() and linenoiseShow() to render the current - * edited line with the proposed completion. If the current completion table - * is already available, it is passed as second argument, otherwise the - * function will use the callback to obtain it. - * - * Flags are the same as refreshLine*(), that is REFRESH_* macros. */ -static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) { - /* Obtain the table of completions if the caller didn't provide one. */ - linenoiseCompletions ctable; - if (lc == NULL) { - completionCallback(ls->buf, &ctable); - lc = &ctable; - } - - /* Show the edited line with completion if possible, or just refresh. */ - if (ls->completion_idx < lc->len) { - struct linenoiseState saved = *ls; - ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]); - ls->buf = lc->cvec[ls->completion_idx]; - refreshLineWithFlags(ls, flags); - ls->len = saved.len; - ls->pos = saved.pos; - ls->buf = saved.buf; - } else { - refreshLineWithFlags(ls, flags); - } - - if (lc == &ctable) { - ctable.to_free = false; - } -} - -enum ESC_TYPE { ESC_NULL = 0, ESC_DELETE, ESC_UP, ESC_DOWN, ESC_RIGHT, ESC_LEFT, ESC_HOME, ESC_END }; - -static ESC_TYPE readEscapeSequence(struct linenoiseState * l) { - /* Check if the file input has additional data. */ - struct pollfd pfd; - pfd.fd = l->ifd; - pfd.events = POLLIN; - - auto ret = poll(&pfd, 1, 1); // 1 millisecond timeout - if (ret <= 0) { // -1: error, 0: timeout - return ESC_NULL; - } - - /* Read the next two bytes representing the escape sequence. - * Use two calls to handle slow terminals returning the two - * chars at different times. */ - char seq[3]; - if (read(l->ifd, seq, 1) == -1) { - return ESC_NULL; - } - if (read(l->ifd, seq + 1, 1) == -1) { - return ESC_NULL; - } - - /* ESC [ sequences. */ - if (seq[0] == '[') { - if (seq[1] >= '0' && seq[1] <= '9') { - /* Extended escape, read additional byte. */ - if (read(l->ifd, seq + 2, 1) == -1) { - return ESC_NULL; - } - if (seq[2] == '~') { - switch (seq[1]) { - case '3': - return ESC_DELETE; - } - } - } else { - switch (seq[1]) { - case 'A': - return ESC_UP; - case 'B': - return ESC_DOWN; - case 'C': - return ESC_RIGHT; - case 'D': - return ESC_LEFT; - case 'H': - return ESC_HOME; - case 'F': - return ESC_END; - } - } - } - - /* ESC O sequences. */ - else if (seq[0] == 'O') { - switch (seq[1]) { - case 'H': - return ESC_HOME; - case 'F': - return ESC_END; - } - } - return ESC_NULL; -} - -/* This is an helper function for linenoiseEdit*() and is called when the - * user types the key in order to complete the string currently in the - * input. - * - * The state of the editing is encapsulated into the pointed linenoiseState - * structure as described in the structure definition. - * - * If the function returns non-zero, the caller should handle the - * returned value as a byte read from the standard input, and process - * it as usually: this basically means that the function may return a byte - * read from the terminal but not processed. Otherwise, if zero is returned, - * the input was consumed by the completeLine() function to navigate the - * possible completions, and the caller should read for the next characters - * from stdin. */ -static int completeLine(struct linenoiseState * ls, int keypressed, ESC_TYPE esc_type) { - linenoiseCompletions lc; - int nwritten; - char c = keypressed; - - completionCallback(ls->buf, &lc); - if (lc.len == 0) { - linenoiseBeep(); - ls->in_completion = 0; - } else { - if (c == TAB) { - if (ls->in_completion == 0) { - ls->in_completion = 1; - ls->completion_idx = 0; - } else { - ls->completion_idx = (ls->completion_idx + 1) % (lc.len + 1); - if (ls->completion_idx == lc.len) { - linenoiseBeep(); - } - } - c = 0; - } else if (c == ESC && esc_type == ESC_NULL) { - /* Re-show original buffer */ - if (ls->completion_idx < lc.len) { - refreshLine(ls); - } - ls->in_completion = 0; - c = 0; - } else { - /* Update buffer and return */ - if (ls->completion_idx < lc.len) { - nwritten = snprintf(ls->buf, ls->buflen, "%s", lc.cvec[ls->completion_idx]); - ls->len = ls->pos = nwritten; - } - ls->in_completion = 0; - } - - /* Show completion or original buffer */ - if (ls->in_completion && ls->completion_idx < lc.len) { - refreshLineWithCompletion(ls, &lc, REFRESH_ALL); - } else { - refreshLine(ls); - } - } - - return c; /* Return last read character */ -} - -/* Register a callback function to be called for tab-completion. */ -void linenoiseSetCompletionCallback(linenoiseCompletionCallback *fn) { - completionCallback = fn; -} - -/* Register a hits function to be called to show hits to the user at the - * right of the prompt. */ -void linenoiseSetHintsCallback(linenoiseHintsCallback *fn) { - hintsCallback = fn; -} - -/* Register a function to free the hints returned by the hints callback - * registered with linenoiseSetHintsCallback(). */ -void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) { - freeHintsCallback = fn; -} - -/* This function is used by the callback function registered by the user - * in order to add completion options given the input string when the - * user typed . See the example.c source code for a very easy to - * understand example. */ -void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) { - const size_t len = strlen(str); - auto copy = std::make_unique(len + 1); - if (!copy) { - return; - } - - memcpy(copy.get(), str, len + 1); - char ** cvec = static_cast(std::realloc(lc->cvec, sizeof(char *) * (lc->len + 1))); - if (cvec == nullptr) { - return; - } - - lc->cvec = cvec; - lc->cvec[lc->len++] = copy.release(); -} - -/* Get column length from begining of buffer to current byte position */ -static size_t columnPos(const char * buf, size_t buf_len, size_t pos) { - size_t ret = 0; - size_t off = 0; - while (off < pos) { - size_t col_len; - size_t len = nextCharLen(buf, buf_len, off, &col_len); - off += len; - ret += col_len; - } - return ret; -} - -/* Helper of refreshSingleLine() and refreshMultiLine() to show hints - * to the right of the prompt. */ -static void refreshShowHints(std::string & ab, struct linenoiseState * l, int pcollen) { - char seq[64]; - size_t collen = pcollen + columnPos(l->buf, l->len, l->len); - if (hintsCallback && collen < l->cols) { - int color = -1, bold = 0; - const char *hint = hintsCallback(l->buf,&color,&bold); - if (hint) { - int hintlen = strlen(hint); - int hintmaxlen = l->cols - collen; - if (hintlen > hintmaxlen) hintlen = hintmaxlen; - if (bold == 1 && color == -1) color = 37; - if (color != -1 || bold != 0) - snprintf(seq,64,"\033[%d;%d;49m",bold,color); - else - seq[0] = '\0'; - ab.append(seq); - ab.append(hint, hintlen); - if (color != -1 || bold != 0) - ab.append("\033[0m"); - - /* Call the function to free the hint returned. */ - if (freeHintsCallback) freeHintsCallback(hint); - } - } -} - -/* Check if text is an ANSI escape sequence */ -static int isAnsiEscape(const char * buf, size_t buf_len, size_t * len) { - if (buf_len > 2 && !memcmp("\033[", buf, 2)) { - size_t off = 2; - while (off < buf_len) { - switch (buf[off++]) { - case 'A': - case 'B': - case 'C': - case 'D': - case 'E': - case 'F': - case 'G': - case 'H': - case 'J': - case 'K': - case 'S': - case 'T': - case 'f': - case 'm': - *len = off; - return 1; - } - } - } - return 0; -} - -/* Get column length of prompt text */ -static size_t promptTextColumnLen(const char * prompt, size_t plen) { - char buf[LINENOISE_MAX_LINE]; - size_t buf_len = 0; - size_t off = 0; - while (off < plen) { - size_t len; - if (isAnsiEscape(prompt + off, plen - off, &len)) { - off += len; - continue; - } - buf[buf_len++] = prompt[off++]; - } - return columnPos(buf, buf_len, buf_len); -} - -/* Single line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. - * - * Flags is REFRESH_* macros. The function can just remove the old - * prompt, just write it, or both. */ -static void refreshSingleLine(struct linenoiseState *l, int flags) { - char seq[64]; - size_t pcollen = promptTextColumnLen(l->prompt, strlen(l->prompt)); - int fd = l->ofd; - char *buf = l->buf; - size_t len = l->len; - size_t pos = l->pos; - std::string ab; - - while ((pcollen + columnPos(buf, len, pos)) >= l->cols) { - int chlen = nextCharLen(buf, len, 0, NULL); - buf += chlen; - len -= chlen; - pos -= chlen; - } - while (pcollen + columnPos(buf, len, len) > l->cols) { - len -= prevCharLen(buf, len, len, NULL); - } - - /* Cursor to left edge */ - snprintf(seq,sizeof(seq),"\r"); - ab.append(seq); - - if (flags & REFRESH_WRITE) { - /* Write the prompt and the current buffer content */ - ab.append(l->prompt); - if (maskmode == 1) { - while (len--) { - ab.append("*"); - } - } else { - ab.append(buf, len); - } - /* Show hits if any. */ - refreshShowHints(ab, l, pcollen); - } - - /* Erase to right */ - snprintf(seq,sizeof(seq),"\x1b[0K"); - ab.append(seq); - if (flags & REFRESH_WRITE) { - /* Move cursor to original position. */ - snprintf(seq, sizeof(seq), "\r\x1b[%dC", (int) (columnPos(buf, len, pos) + pcollen)); - ab.append(seq); - } - - (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ -} - -/* Get column length from begining of buffer to current byte position for multiline mode*/ -static size_t columnPosForMultiLine(const char * buf, size_t buf_len, size_t pos, size_t cols, size_t ini_pos) { - size_t ret = 0; - size_t colwid = ini_pos; - - size_t off = 0; - while (off < buf_len) { - size_t col_len; - size_t len = nextCharLen(buf, buf_len, off, &col_len); - - int dif = (int) (colwid + col_len) - (int) cols; - if (dif > 0) { - ret += dif; - colwid = col_len; - } else if (dif == 0) { - colwid = 0; - } else { - colwid += col_len; - } - - if (off >= pos) { - break; - } - off += len; - ret += col_len; - } - - return ret; -} - -/* Multi line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. - * - * Flags is REFRESH_* macros. The function can just remove the old - * prompt, just write it, or both. */ -static void refreshMultiLine(struct linenoiseState *l, int flags) { - char seq[64]; - size_t pcollen = promptTextColumnLen(l->prompt, strlen(l->prompt)); - int colpos = columnPosForMultiLine(l->buf, l->len, l->len, l->cols, pcollen); - int colpos2; /* cursor column position. */ - int rows = (pcollen + colpos + l->cols - 1) / l->cols; /* rows used by current buf. */ - int rpos = (pcollen + l->oldcolpos + l->cols) / l->cols; /* cursor relative row. */ - int rpos2; /* rpos after refresh. */ - int col; /* column position, zero-based. */ - int old_rows = l->oldrows; - int fd = l->ofd, j; - std::string ab; - l->oldrows = rows; - - /* First step: clear all the lines used before. To do so start by - * going to the last row. */ - if (flags & REFRESH_CLEAN) { - if (old_rows - rpos > 0) { - snprintf(seq,64,"\x1b[%dB", old_rows-rpos); - ab.append(seq); - } - - /* Now for every row clear it, go up. */ - for (j = 0; j < old_rows - 1; j++) { - snprintf(seq,64,"\r\x1b[0K\x1b[1A"); - ab.append(seq); - } - } - - if (flags & REFRESH_ALL) { - /* Clean the top line. */ - snprintf(seq,64,"\r\x1b[0K"); - ab.append(seq); - } - - /* Get column length to cursor position */ - colpos2 = columnPosForMultiLine(l->buf, l->len, l->pos, l->cols, pcollen); - - if (flags & REFRESH_WRITE) { - /* Write the prompt and the current buffer content */ - ab.append(l->prompt); - if (maskmode == 1) { - for (unsigned int i = 0; i < l->len; ++i) { - ab.append("*"); - } - } else { - ab.append(l->buf, l->len); - } - - /* Show hits if any. */ - refreshShowHints(ab, l, pcollen); - - /* If we are at the very end of the screen with our prompt, we need to - * emit a newline and move the prompt to the first column. */ - if (l->pos && l->pos == l->len && (colpos2 + pcollen) % l->cols == 0) { - ab.append("\n"); - snprintf(seq,64,"\r"); - ab.append(seq); - rows++; - if (rows > (int)l->oldrows) l->oldrows = rows; - } - - /* Move cursor to right position. */ - rpos2 = (pcollen + colpos2 + l->cols) / l->cols; /* Current cursor relative row */ - - /* Go up till we reach the expected position. */ - if (rows - rpos2 > 0) { - snprintf(seq,64,"\x1b[%dA", rows-rpos2); - ab.append(seq); - } - - /* Set column. */ - col = (pcollen + colpos2) % l->cols; - if (col) - snprintf(seq,64,"\r\x1b[%dC", col); - else - snprintf(seq,64,"\r"); - ab.append(seq); - } - - l->oldcolpos = colpos2; - - (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ -} - -/* Calls the two low level functions refreshSingleLine() or - * refreshMultiLine() according to the selected mode. */ -static void refreshLineWithFlags(struct linenoiseState *l, int flags) { - if (mlmode) - refreshMultiLine(l,flags); - else - refreshSingleLine(l,flags); -} - -/* Utility function to avoid specifying REFRESH_ALL all the times. */ -static void refreshLine(struct linenoiseState *l) { - refreshLineWithFlags(l,REFRESH_ALL); -} - -/* Hide the current line, when using the multiplexing API. */ -void linenoiseHide(struct linenoiseState *l) { - if (mlmode) - refreshMultiLine(l,REFRESH_CLEAN); - else - refreshSingleLine(l,REFRESH_CLEAN); -} - -/* Show the current line, when using the multiplexing API. */ -void linenoiseShow(struct linenoiseState *l) { - if (l->in_completion) { - refreshLineWithCompletion(l,NULL,REFRESH_WRITE); - } else { - refreshLineWithFlags(l,REFRESH_WRITE); - } -} - -/* Insert the character 'c' at cursor current position. - * - * On error writing to the terminal -1 is returned, otherwise 0. */ -static int linenoiseEditInsert(struct linenoiseState * l, const char * cbuf, int clen) { - if (l->len + clen <= l->buflen) { - if (l->len == l->pos) { - memcpy(&l->buf[l->pos], cbuf, clen); - l->pos += clen; - l->len += clen; - ; - l->buf[l->len] = '\0'; - if ((!mlmode && promptTextColumnLen(l->prompt, l->plen) + columnPos(l->buf, l->len, l->len) < l->cols && - !hintsCallback)) { - /* Avoid a full update of the line in the - * trivial case. */ - if (maskmode == 1) { - static const char d = '*'; - if (write(l->ofd, &d, 1) == -1) { - return -1; - } - } else { - if (write(l->ofd, cbuf, clen) == -1) { - return -1; - } - } - } else { - refreshLine(l); - } - } else { - memmove(l->buf + l->pos + clen, l->buf + l->pos, l->len - l->pos); - memcpy(&l->buf[l->pos], cbuf, clen); - l->pos += clen; - l->len += clen; - l->buf[l->len] = '\0'; - refreshLine(l); - } - } - return 0; -} - -/* Move cursor on the left. */ -static void linenoiseEditMoveLeft(struct linenoiseState * l) { - if (l->pos > 0) { - l->pos -= prevCharLen(l->buf, l->len, l->pos, NULL); - refreshLine(l); - } -} - -/* Move cursor on the right. */ -static void linenoiseEditMoveRight(struct linenoiseState * l) { - if (l->pos != l->len) { - l->pos += nextCharLen(l->buf, l->len, l->pos, NULL); - refreshLine(l); - } -} - -/* Move cursor to the start of the line. */ -static void linenoiseEditMoveHome(struct linenoiseState * l) { - if (l->pos != 0) { - l->pos = 0; - refreshLine(l); - } -} - -/* Move cursor to the end of the line. */ -static void linenoiseEditMoveEnd(struct linenoiseState * l) { - if (l->pos != l->len) { - l->pos = l->len; - refreshLine(l); - } -} - -/* Substitute the currently edited line with the next or previous history - * entry as specified by 'dir'. */ -#define LINENOISE_HISTORY_NEXT 0 -#define LINENOISE_HISTORY_PREV 1 - -static void linenoiseEditHistoryNext(struct linenoiseState * l, int dir) { - if (history_len > 1) { - /* Update the current history entry before to - * overwrite it with the next one. */ - free(history[history_len - 1 - l->history_index]); - history[history_len - 1 - l->history_index] = strdup(l->buf); - /* Show the new entry */ - l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; - if (l->history_index < 0) { - l->history_index = 0; - return; - } else if (l->history_index >= history_len) { - l->history_index = history_len-1; - return; - } - strncpy(l->buf,history[history_len - 1 - l->history_index],l->buflen); - l->buf[l->buflen-1] = '\0'; - l->len = l->pos = strlen(l->buf); - refreshLine(l); - } -} - -/* Delete the character at the right of the cursor without altering the cursor - * position. Basically this is what happens with the "Delete" keyboard key. */ -static void linenoiseEditDelete(struct linenoiseState * l) { - if (l->len > 0 && l->pos < l->len) { - int chlen = nextCharLen(l->buf, l->len, l->pos, NULL); - memmove(l->buf + l->pos, l->buf + l->pos + chlen, l->len - l->pos - chlen); - l->len -= chlen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Backspace implementation. */ -static void linenoiseEditBackspace(struct linenoiseState * l) { - if (l->pos > 0 && l->len > 0) { - int chlen = prevCharLen(l->buf, l->len, l->pos, NULL); - memmove(l->buf + l->pos - chlen, l->buf + l->pos, l->len - l->pos); - l->pos -= chlen; - l->len -= chlen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Delete the previous word, maintaining the cursor at the start of the - * current word. */ -static void linenoiseEditDeletePrevWord(struct linenoiseState * l) { - size_t old_pos = l->pos; - size_t diff; - - while (l->pos > 0 && l->buf[l->pos-1] == ' ') - l->pos--; - while (l->pos > 0 && l->buf[l->pos-1] != ' ') - l->pos--; - diff = old_pos - l->pos; - memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); - l->len -= diff; - refreshLine(l); -} - -/* This function is part of the multiplexed API of Linenoise, that is used - * in order to implement the blocking variant of the API but can also be - * called by the user directly in an event driven program. It will: - * - * 1. Initialize the linenoise state passed by the user. - * 2. Put the terminal in RAW mode. - * 3. Show the prompt. - * 4. Return control to the user, that will have to call linenoiseEditFeed() - * each time there is some data arriving in the standard input. - * - * The user can also call linenoiseEditHide() and linenoiseEditShow() if it - * is required to show some input arriving asynchronously, without mixing - * it with the currently edited line. - * - * When linenoiseEditFeed() returns non-NULL, the user finished with the - * line editing session (pressed enter CTRL-D/C): in this case the caller - * needs to call linenoiseEditStop() to put back the terminal in normal - * mode. This will not destroy the buffer, as long as the linenoiseState - * is still valid in the context of the caller. - * - * The function returns 0 on success, or -1 if writing to standard output - * fails. If stdin_fd or stdout_fd are set to -1, the default is to use - * STDIN_FILENO and STDOUT_FILENO. - */ -int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) { - /* Populate the linenoise state that we pass to functions implementing - * specific editing functionalities. */ - l->in_completion = 0; - l->ifd = stdin_fd != -1 ? stdin_fd : STDIN_FILENO; - l->ofd = stdout_fd != -1 ? stdout_fd : STDOUT_FILENO; - l->buf = buf; - l->buflen = buflen; - l->prompt = prompt; - l->plen = strlen(prompt); - l->oldcolpos = l->pos = 0; - l->len = 0; - - /* Enter raw mode. */ - if (enableRawMode(l->ifd) == -1) return -1; - - l->cols = getColumns(stdin_fd, stdout_fd); - l->oldrows = 0; - l->history_index = 0; - - /* Buffer starts empty. */ - l->buf[0] = '\0'; - l->buflen--; /* Make sure there is always space for the nullterm */ - - /* If stdin is not a tty, stop here with the initialization. We - * will actually just read a line from standard input in blocking - * mode later, in linenoiseEditFeed(). */ - if (!isatty(l->ifd)) return 0; - - /* The latest history entry is always our current buffer, that - * initially is just an empty string. */ - linenoiseHistoryAdd(""); - - if (write(l->ofd,prompt,l->plen) == -1) return -1; - return 0; -} - -const char* linenoiseEditMore = "If you see this, you are misusing the API: when linenoiseEditFeed() is called, if it returns linenoiseEditMore the user is yet editing the line. See the README file for more information."; - -static const char * handleEnterKey(struct linenoiseState * l) { - --history_len; - free(history[history_len]); - if (mlmode) { - linenoiseEditMoveEnd(l); - } - if (hintsCallback) { - /* Force a refresh without hints to leave the previous - * line as the user typed it after a newline. */ - linenoiseHintsCallback * hc = hintsCallback; - hintsCallback = NULL; - refreshLine(l); - hintsCallback = hc; - } - - return strdup(l->buf); -} - -static const char * handleCtrlCKey() { - errno = EAGAIN; - return NULL; -} - -static const char * handleCtrlDKey(struct linenoiseState * l) { - if (l->len > 0) { - linenoiseEditDelete(l); - return linenoiseEditMore; - } - - --history_len; - free(history[history_len]); - errno = ENOENT; - return NULL; -} - -static void handleCtrlTKey(struct linenoiseState * l) { - if (l->pos > 0 && l->pos < l->len) { - auto prev_chlen = prevCharLen(l->buf, l->len, l->pos, NULL); - auto curr_chlen = nextCharLen(l->buf, l->len, l->pos, NULL); - - std::string prev_char(prev_chlen, 0); - memcpy(prev_char.data(), l->buf + l->pos - prev_chlen, prev_chlen); - memmove(l->buf + l->pos - prev_chlen, l->buf + l->pos, curr_chlen); - memmove(l->buf + l->pos - prev_chlen + curr_chlen, prev_char.data(), prev_chlen); - - l->pos = l->pos - prev_chlen + curr_chlen; - if (l->pos + prev_chlen != l->len) { - l->pos += prev_chlen; - } - - refreshLine(l); - } -} - -static void handleEscapeSequence(struct linenoiseState * l, int esc_type) { - switch (esc_type) { - case ESC_NULL: - break; - case ESC_DELETE: - linenoiseEditDelete(l); - break; - case ESC_UP: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); - break; - case ESC_DOWN: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); - break; - case ESC_RIGHT: - linenoiseEditMoveRight(l); - break; - case ESC_LEFT: - linenoiseEditMoveLeft(l); - break; - case ESC_HOME: - linenoiseEditMoveHome(l); - break; - case ESC_END: - linenoiseEditMoveEnd(l); - break; - } -} - -static void handleCtrlUKey(struct linenoiseState * l) { - l->buf[0] = '\0'; - l->pos = l->len = 0; - refreshLine(l); -} - -static void handleCtrlKKey(struct linenoiseState * l) { - l->buf[l->pos] = '\0'; - l->len = l->pos; - refreshLine(l); -} - -static const char * processInputCharacter(struct linenoiseState * l, int c, char * cbuf, int nread, int esc_type) { - switch (c) { - case ENTER: - return handleEnterKey(l); - case CTRL_C: - return handleCtrlCKey(); - case BACKSPACE: - case CTRL_H: - linenoiseEditBackspace(l); - break; - case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the - line is empty, act as end-of-file. */ - return handleCtrlDKey(l); - case CTRL_T: - handleCtrlTKey(l); - break; - case CTRL_B: - linenoiseEditMoveLeft(l); - break; - case CTRL_F: - linenoiseEditMoveRight(l); - break; - case CTRL_P: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); - break; - case CTRL_N: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); - break; - case ESC: - handleEscapeSequence(l, esc_type); - break; - default: - if (linenoiseEditInsert(l, cbuf, nread)) { - return NULL; - } - break; - case CTRL_U: /* Ctrl+u, delete the whole line. */ - handleCtrlUKey(l); - break; - case CTRL_K: /* Ctrl+k, delete from current to end of line. */ - handleCtrlKKey(l); - break; - case CTRL_A: /* Ctrl+a, go to the start of the line */ - linenoiseEditMoveHome(l); - break; - case CTRL_E: /* ctrl+e, go to the end of the line */ - linenoiseEditMoveEnd(l); - break; - case CTRL_L: /* ctrl+l, clear screen */ - linenoiseClearScreen(); - refreshLine(l); - break; - case CTRL_W: /* ctrl+w, delete previous word */ - linenoiseEditDeletePrevWord(l); - break; - } - return linenoiseEditMore; -} - -/* This function is part of the multiplexed API of linenoise, see the top - * comment on linenoiseEditStart() for more information. Call this function - * each time there is some data to read from the standard input file - * descriptor. In the case of blocking operations, this function can just be - * called in a loop, and block. - * - * The function returns linenoiseEditMore to signal that line editing is still - * in progress, that is, the user didn't yet pressed enter / CTRL-D. Otherwise - * the function returns the pointer to the heap-allocated buffer with the - * edited line, that the user should free with linenoiseFree(). - * - * On special conditions, NULL is returned and errno is populated: - * - * EAGAIN if the user pressed Ctrl-C - * ENOENT if the user pressed Ctrl-D - * - * Some other errno: I/O error. - */ -const char * linenoiseEditFeed(struct linenoiseState * l) { - /* Not a TTY, pass control to line reading without character count - * limits. */ - if (!isatty(l->ifd)) return linenoiseNoTTY(); - - int c; - int nread; - char cbuf[32]; - - nread = readCode(l->ifd, cbuf, sizeof(cbuf), &c); - if (nread <= 0) return NULL; - - auto esc_type = ESC_NULL; - if (c == ESC) { - esc_type = readEscapeSequence(l); - } - - /* Only autocomplete when the callback is set. It returns < 0 when - * there was an error reading from fd. Otherwise it will return the - * character that should be handled next. */ - if ((l->in_completion || c == 9) && completionCallback != NULL) { - c = completeLine(l, c, esc_type); - /* Read next character when 0 */ - if (c == 0) return linenoiseEditMore; - } - - return processInputCharacter(l, c, cbuf, nread, esc_type); -} - -/* This is part of the multiplexed linenoise API. See linenoiseEditStart() - * for more information. This function is called when linenoiseEditFeed() - * returns something different than NULL. At this point the user input - * is in the buffer, and we can restore the terminal in normal mode. */ -void linenoiseEditStop(struct linenoiseState *l) { - if (!isatty(l->ifd)) return; - disableRawMode(l->ifd); - printf("\n"); -} - -/* This just implements a blocking loop for the multiplexed API. - * In many applications that are not event-driven, we can just call - * the blocking linenoise API, wait for the user to complete the editing - * and return the buffer. */ -static const char *linenoiseBlockingEdit(int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) -{ - struct linenoiseState l; - - /* Editing without a buffer is invalid. */ - if (buflen == 0) { - errno = EINVAL; - return NULL; - } - - linenoiseEditStart(&l,stdin_fd,stdout_fd,buf,buflen,prompt); - const char *res; - while((res = linenoiseEditFeed(&l)) == linenoiseEditMore); - linenoiseEditStop(&l); - return res; -} - -/* This special mode is used by linenoise in order to print scan codes - * on screen for debugging / development purposes. It is implemented - * by the linenoise_example program using the --keycodes option. */ -void linenoisePrintKeyCodes(void) { - char quit[4]; - - printf("Linenoise key codes debugging mode.\n" - "Press keys to see scan codes. Type 'quit' at any time to exit.\n"); - if (enableRawMode(STDIN_FILENO) == -1) return; - memset(quit,' ',4); - while(1) { - char c; - int nread; - - nread = read(STDIN_FILENO,&c,1); - if (nread <= 0) continue; - memmove(quit,quit+1,sizeof(quit)-1); /* shift string to left. */ - quit[sizeof(quit)-1] = c; /* Insert current char on the right. */ - if (memcmp(quit,"quit",sizeof(quit)) == 0) break; - - printf("'%c' %02x (%d) (type quit to exit)\n", isprint((int) c) ? c : '?', (int) c, (int) c); - printf("\r"); /* Go left edge manually, we are in raw mode. */ - fflush(stdout); - } - disableRawMode(STDIN_FILENO); -} - -/* This function is called when linenoise() is called with the standard - * input file descriptor not attached to a TTY. So for example when the - * program using linenoise is called in pipe or with a file redirected - * to its standard input. In this case, we want to be able to return the - * line regardless of its length (by default we are limited to 4k). */ -static char *linenoiseNoTTY(void) { - char *line = NULL; - size_t len = 0, maxlen = 0; - - while(1) { - if (len == maxlen) { - if (maxlen == 0) maxlen = 16; - maxlen *= 2; - char *oldval = line; - line = (char*) realloc(line,maxlen); - if (line == NULL) { - if (oldval) free(oldval); - return NULL; - } - } - int c = fgetc(stdin); - if (c == EOF || c == '\n') { - if (c == EOF && len == 0) { - free(line); - return NULL; - } else { - line[len] = '\0'; - return line; - } - } else { - line[len] = c; - len++; - } - } -} - -/* The high level function that is the main API of the linenoise library. - * This function checks if the terminal has basic capabilities, just checking - * for a blacklist of stupid terminals, and later either calls the line - * editing function or uses dummy fgets() so that you will be able to type - * something even in the most desperate of the conditions. */ -const char *linenoise(const char *prompt) { - char buf[LINENOISE_MAX_LINE]; - - if (!isatty(STDIN_FILENO)) { - /* Not a tty: read from file / pipe. In this mode we don't want any - * limit to the line size, so we call a function to handle that. */ - return linenoiseNoTTY(); - } else if (isUnsupportedTerm()) { - size_t len; - - printf("%s",prompt); - fflush(stdout); - if (fgets(buf,LINENOISE_MAX_LINE,stdin) == NULL) return NULL; - len = strlen(buf); - while(len && (buf[len-1] == '\n' || buf[len-1] == '\r')) { - len--; - buf[len] = '\0'; - } - return strdup(buf); - } else { - const char *retval = linenoiseBlockingEdit(STDIN_FILENO,STDOUT_FILENO,buf,LINENOISE_MAX_LINE,prompt); - return retval; - } -} - -/* This is just a wrapper the user may want to call in order to make sure - * the linenoise returned buffer is freed with the same allocator it was - * created with. Useful when the main program is using an alternative - * allocator. */ -void linenoiseFree(void *ptr) { - if (ptr == linenoiseEditMore) return; // Protect from API misuse. - free(ptr); -} - -/* ================================ History ================================= */ - -/* Free the history, but does not reset it. Only used when we have to - * exit() to avoid memory leaks are reported by valgrind & co. */ -static void freeHistory(void) { - if (history) { - int j; - - for (j = 0; j < history_len; j++) - free(history[j]); - free(history); - } -} - -/* At exit we'll try to fix the terminal to the initial conditions. */ -static void linenoiseAtExit(void) { - disableRawMode(STDIN_FILENO); - freeHistory(); -} - -/* This is the API call to add a new entry in the linenoise history. - * It uses a fixed array of char pointers that are shifted (memmoved) - * when the history max length is reached in order to remove the older - * entry and make room for the new one, so it is not exactly suitable for huge - * histories, but will work well for a few hundred of entries. - * - * Using a circular buffer is smarter, but a bit more complex to handle. */ -int linenoiseHistoryAdd(const char *line) { - char *linecopy; - - if (history_max_len == 0) return 0; - - /* Initialization on first call. */ - if (history == NULL) { - history = (char**) malloc(sizeof(char*)*history_max_len); - if (history == NULL) return 0; - memset(history,0,(sizeof(char*)*history_max_len)); - } - - /* Don't add duplicated lines. */ - if (history_len && !strcmp(history[history_len-1], line)) return 0; - - /* Add an heap allocated copy of the line in the history. - * If we reached the max length, remove the older line. */ - linecopy = strdup(line); - if (!linecopy) return 0; - if (history_len == history_max_len) { - free(history[0]); - memmove(history,history+1,sizeof(char*)*(history_max_len-1)); - history_len--; - } - history[history_len] = linecopy; - history_len++; - return 1; -} - -/* Set the maximum length for the history. This function can be called even - * if there is already some history, the function will make sure to retain - * just the latest 'len' elements if the new history length value is smaller - * than the amount of items already inside the history. */ -int linenoiseHistorySetMaxLen(int len) { - char **new_ptr; - - if (len < 1) return 0; - if (history) { - int tocopy = history_len; - - new_ptr = (char**) malloc(sizeof(char*)*len); - if (new_ptr == NULL) return 0; - - /* If we can't copy everything, free the elements we'll not use. */ - if (len < tocopy) { - int j; - - for (j = 0; j < tocopy-len; j++) free(history[j]); - tocopy = len; - } - memset(new_ptr,0,sizeof(char*)*len); - memcpy(new_ptr,history+(history_len-tocopy), sizeof(char*)*tocopy); - free(history); - history = new_ptr; - } - history_max_len = len; - if (history_len > history_max_len) - history_len = history_max_len; - return 1; -} - -/* Save the history in the specified file. On success 0 is returned - * otherwise -1 is returned. */ -int linenoiseHistorySave(const char *filename) { - mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO); - File file; - file.open(filename, "w"); - umask(old_umask); - if (file.file == NULL) { - return -1; - } - chmod(filename,S_IRUSR|S_IWUSR); - for (int j = 0; j < history_len; ++j) { - fprintf(file.file, "%s\n", history[j]); - } - - return 0; -} - -/* Load the history from the specified file. If the file does not exist - * zero is returned and no operation is performed. - * - * If the file exists and the operation succeeded 0 is returned, otherwise - * on error -1 is returned. */ -int linenoiseHistoryLoad(const char *filename) { - File file; - file.open(filename, "r"); - char buf[LINENOISE_MAX_LINE]; - if (file.file == NULL) { - return -1; - } - - while (fgets(buf, LINENOISE_MAX_LINE, file.file) != NULL) { - char *p; - - p = strchr(buf,'\r'); - if (!p) p = strchr(buf,'\n'); - if (p) *p = '\0'; - linenoiseHistoryAdd(buf); - } - return 0; -} -#endif diff --git a/tools/run/linenoise.cpp/linenoise.h b/tools/run/linenoise.cpp/linenoise.h deleted file mode 100644 index 9823ca36d0..0000000000 --- a/tools/run/linenoise.cpp/linenoise.h +++ /dev/null @@ -1,137 +0,0 @@ -/* linenoise.h -- VERSION 1.0 - * - * Guerrilla line editing library against the idea that a line editing lib - * needs to be 20,000 lines of C++ code. - * - * See linenoise.cpp for more information. - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2010-2023, Salvatore Sanfilippo - * Copyright (c) 2010-2013, Pieter Noordhuis - * Copyright (c) 2025, Eric Curtin - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef __LINENOISE_H -#define __LINENOISE_H - -#ifdef __cplusplus -extern "C" { -#endif - -#include /* For size_t. */ -#include - -extern const char * linenoiseEditMore; - -/* The linenoiseState structure represents the state during line editing. - * We pass this state to functions implementing specific editing - * functionalities. */ -struct linenoiseState { - int in_completion; /* The user pressed TAB and we are now in completion - * mode, so input is handled by completeLine(). */ - size_t completion_idx; /* Index of next completion to propose. */ - int ifd; /* Terminal stdin file descriptor. */ - int ofd; /* Terminal stdout file descriptor. */ - char * buf; /* Edited line buffer. */ - size_t buflen; /* Edited line buffer size. */ - const char * prompt; /* Prompt to display. */ - size_t plen; /* Prompt length. */ - size_t pos; /* Current cursor position. */ - size_t oldcolpos; /* Previous refresh cursor column position. */ - size_t len; /* Current edited line length. */ - size_t cols; /* Number of columns in terminal. */ - size_t oldrows; /* Rows used by last refreshed line (multiline mode) */ - int history_index; /* The history index we are currently editing. */ -}; - -struct linenoiseCompletions { - size_t len = 0; - char ** cvec = nullptr; - bool to_free = true; - - ~linenoiseCompletions() { - if (!to_free) { - return; - } - - for (size_t i = 0; i < len; ++i) { - free(cvec[i]); - } - - free(cvec); - } -}; - -/* Non blocking API. */ -int linenoiseEditStart(struct linenoiseState * l, int stdin_fd, int stdout_fd, char * buf, size_t buflen, - const char * prompt); -const char * linenoiseEditFeed(struct linenoiseState * l); -void linenoiseEditStop(struct linenoiseState * l); -void linenoiseHide(struct linenoiseState * l); -void linenoiseShow(struct linenoiseState * l); - -/* Blocking API. */ -const char * linenoise(const char * prompt); -void linenoiseFree(void * ptr); - -/* Completion API. */ -typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *); -typedef const char *(linenoiseHintsCallback) (const char *, int * color, int * bold); -typedef void(linenoiseFreeHintsCallback)(const char *); -void linenoiseSetCompletionCallback(linenoiseCompletionCallback *); -void linenoiseSetHintsCallback(linenoiseHintsCallback *); -void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *); -void linenoiseAddCompletion(linenoiseCompletions *, const char *); - -/* History API. */ -int linenoiseHistoryAdd(const char * line); -int linenoiseHistorySetMaxLen(int len); -int linenoiseHistorySave(const char * filename); -int linenoiseHistoryLoad(const char * filename); - -/* Other utilities. */ -void linenoiseClearScreen(void); -void linenoiseSetMultiLine(int ml); -void linenoisePrintKeyCodes(void); -void linenoiseMaskModeEnable(void); -void linenoiseMaskModeDisable(void); - -/* Encoding functions. */ -typedef size_t(linenoisePrevCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len); -typedef size_t(linenoiseNextCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len); -typedef size_t(linenoiseReadCode)(int fd, char * buf, size_t buf_len, int * c); - -void linenoiseSetEncodingFunctions(linenoisePrevCharLen * prevCharLenFunc, linenoiseNextCharLen * nextCharLenFunc, - linenoiseReadCode * readCodeFunc); - -#ifdef __cplusplus -} -#endif - -#endif /* __LINENOISE_H */ diff --git a/tools/run/run.cpp b/tools/run/run.cpp deleted file mode 100644 index b90a7253c4..0000000000 --- a/tools/run/run.cpp +++ /dev/null @@ -1,1408 +0,0 @@ -#include "chat.h" -#include "common.h" -#include "llama-cpp.h" -#include "log.h" - -#include "linenoise.cpp/linenoise.h" - -#define JSON_ASSERT GGML_ASSERT -#include - -#if defined(_WIN32) -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include -# include -#else -# include -# include -# include -#endif - -#if defined(LLAMA_USE_CURL) -# include -#else -# include "http.h" -#endif - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) -[[noreturn]] static void sigint_handler(int) { - printf("\n" LOG_COL_DEFAULT); - exit(0); // not ideal, but it's the only way to guarantee exit in all cases -} -#endif - -GGML_ATTRIBUTE_FORMAT(1, 2) -static int printe(const char * fmt, ...) { - va_list args; - va_start(args, fmt); - const int ret = vfprintf(stderr, fmt, args); - va_end(args); - - return ret; -} - -static std::string strftime_fmt(const char * fmt, const std::tm & tm) { - std::ostringstream oss; - oss << std::put_time(&tm, fmt); - - return oss.str(); -} - -class Opt { - public: - int init(int argc, const char ** argv) { - ctx_params = llama_context_default_params(); - model_params = llama_model_default_params(); - context_size_default = ctx_params.n_batch; - n_threads_default = ctx_params.n_threads; - ngl_default = model_params.n_gpu_layers; - common_params_sampling sampling; - temperature_default = sampling.temp; - - if (argc < 2) { - printe("Error: No arguments provided.\n"); - print_help(); - return 1; - } - - // Parse arguments - if (parse(argc, argv)) { - printe("Error: Failed to parse arguments.\n"); - print_help(); - return 1; - } - - // If help is requested, show help and exit - if (help) { - print_help(); - return 2; - } - - ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default; - ctx_params.n_ctx = ctx_params.n_batch; - ctx_params.n_threads = ctx_params.n_threads_batch = n_threads >= 0 ? n_threads : n_threads_default; - model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default; - temperature = temperature >= 0 ? temperature : temperature_default; - - return 0; // Success - } - - llama_context_params ctx_params; - llama_model_params model_params; - std::string model_; - std::string chat_template_file; - std::string user; - bool use_jinja = false; - int context_size = -1, ngl = -1, n_threads = -1; - float temperature = -1; - bool verbose = false; - - private: - int context_size_default = -1, ngl_default = -1, n_threads_default = -1; - float temperature_default = -1; - bool help = false; - - bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) { - return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, int & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = std::atoi(argv[++i]); - - return 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = std::atof(argv[++i]); - - return 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = argv[++i]; - - return 0; - } - - int parse_options_with_value(int argc, const char ** argv, int & i, bool & options_parsing) { - if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) { - if (handle_option_with_value(argc, argv, i, context_size) == 1) { - return 1; - } - } else if (options_parsing && - (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) { - if (handle_option_with_value(argc, argv, i, ngl) == 1) { - return 1; - } - } else if (options_parsing && (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--threads") == 0)) { - if (handle_option_with_value(argc, argv, i, n_threads) == 1) { - return 1; - } - } else if (options_parsing && strcmp(argv[i], "--temp") == 0) { - if (handle_option_with_value(argc, argv, i, temperature) == 1) { - return 1; - } - } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) { - if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) { - return 1; - } - use_jinja = true; - } else { - return 2; - } - - return 0; - } - - int parse_options(const char ** argv, int & i, bool & options_parsing) { - if (options_parsing && (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) { - verbose = true; - } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) { - use_jinja = true; - } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { - help = true; - return 0; - } else if (options_parsing && strcmp(argv[i], "--") == 0) { - options_parsing = false; - } else { - return 2; - } - - return 0; - } - - int parse_positional_args(const char ** argv, int & i, int & positional_args_i) { - if (positional_args_i == 0) { - if (!argv[i][0] || argv[i][0] == '-') { - return 1; - } - - ++positional_args_i; - model_ = argv[i]; - } else if (positional_args_i == 1) { - ++positional_args_i; - user = argv[i]; - } else { - user += " " + std::string(argv[i]); - } - - return 0; - } - - int parse(int argc, const char ** argv) { - bool options_parsing = true; - for (int i = 1, positional_args_i = 0; i < argc; ++i) { - int ret = parse_options_with_value(argc, argv, i, options_parsing); - if (ret == 0) { - continue; - } else if (ret == 1) { - return ret; - } - - ret = parse_options(argv, i, options_parsing); - if (ret == 0) { - continue; - } else if (ret == 1) { - return ret; - } - - if (parse_positional_args(argv, i, positional_args_i)) { - return 1; - } - } - - if (model_.empty()) { - return 1; - } - - return 0; - } - - void print_help() const { - printf( - "Description:\n" - " Runs a llm\n" - "\n" - "Usage:\n" - " llama-run [options] model [prompt]\n" - "\n" - "Options:\n" - " -c, --context-size \n" - " Context size (default: %d)\n" - " --chat-template-file \n" - " Path to the file containing the chat template to use with the model.\n" - " Only supports jinja templates and implicitly sets the --jinja flag.\n" - " --jinja\n" - " Use jinja templating for the chat template of the model\n" - " -n, -ngl, --ngl \n" - " Number of GPU layers (default: %d)\n" - " --temp \n" - " Temperature (default: %.1f)\n" - " -t, --threads \n" - " Number of threads to use during generation (default: %d)\n" - " -v, --verbose, --log-verbose\n" - " Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n" - " -h, --help\n" - " Show help message\n" - "\n" - "Commands:\n" - " model\n" - " Model is a string with an optional prefix of \n" - " huggingface:// (hf://), modelscope:// (ms://), ollama://, https:// or file://.\n" - " If no protocol is specified and a file exists in the specified\n" - " path, file:// is assumed, otherwise if a file does not exist in\n" - " the specified path, ollama:// is assumed. Models that are being\n" - " pulled are downloaded with .partial extension while being\n" - " downloaded and then renamed as the file without the .partial\n" - " extension when complete.\n" - "\n" - "Examples:\n" - " llama-run llama3\n" - " llama-run ollama://granite-code\n" - " llama-run ollama://smollm:135m\n" - " llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" - " llama-run " - "huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" - " llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" - " llama-run " - "modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" - " llama-run https://example.com/some-file1.gguf\n" - " llama-run some-file2.gguf\n" - " llama-run file://some-file3.gguf\n" - " llama-run --ngl 999 some-file4.gguf\n" - " llama-run --ngl 999 some-file5.gguf Hello World\n", - context_size_default, ngl_default, temperature_default, n_threads_default); - } -}; - -struct progress_data { - size_t file_size = 0; - std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); - bool printed = false; -}; - -static int get_terminal_width() { -#if defined(_WIN32) - CONSOLE_SCREEN_BUFFER_INFO csbi; - GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); - return csbi.srWindow.Right - csbi.srWindow.Left + 1; -#else - struct winsize w; - ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); - return w.ws_col; -#endif -} - -class File { - public: - FILE * file = nullptr; - - FILE * open(const std::string & filename, const char * mode) { - file = ggml_fopen(filename.c_str(), mode); - - return file; - } - - int lock() { - if (file) { -# ifdef _WIN32 - fd = _fileno(file); - hFile = (HANDLE) _get_osfhandle(fd); - if (hFile == INVALID_HANDLE_VALUE) { - fd = -1; - - return 1; - } - - OVERLAPPED overlapped = {}; - if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD, - &overlapped)) { - fd = -1; - - return 1; - } -# else - fd = fileno(file); - if (flock(fd, LOCK_EX | LOCK_NB) != 0) { - fd = -1; - - return 1; - } -# endif - } - - return 0; - } - - std::string to_string() { - fseek(file, 0, SEEK_END); - const size_t size = ftell(file); - fseek(file, 0, SEEK_SET); - std::string out; - out.resize(size); - const size_t read_size = fread(&out[0], 1, size, file); - if (read_size != size) { - printe("Error reading file: %s", strerror(errno)); - } - - return out; - } - - ~File() { - if (fd >= 0) { -# ifdef _WIN32 - if (hFile != INVALID_HANDLE_VALUE) { - OVERLAPPED overlapped = {}; - UnlockFileEx(hFile, 0, MAXDWORD, MAXDWORD, &overlapped); - } -# else - flock(fd, LOCK_UN); -# endif - } - - if (file) { - fclose(file); - } - } - - private: - int fd = -1; -# ifdef _WIN32 - HANDLE hFile = nullptr; -# endif -}; - -class HttpClient { - public: - int init(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - if (std::filesystem::exists(output_file)) { - return 0; - } - - std::string output_file_partial; - - if (!output_file.empty()) { - output_file_partial = output_file + ".partial"; - } - - if (download(url, headers, output_file_partial, progress, response_str)) { - return 1; - } - - if (!output_file.empty()) { - try { - std::filesystem::rename(output_file_partial, output_file); - } catch (const std::filesystem::filesystem_error & e) { - printe("Failed to rename '%s' to '%s': %s\n", output_file_partial.c_str(), output_file.c_str(), e.what()); - return 1; - } - } - - return 0; - } - -#ifdef LLAMA_USE_CURL - - ~HttpClient() { - if (chunk) { - curl_slist_free_all(chunk); - } - - if (curl) { - curl_easy_cleanup(curl); - } - } - - private: - CURL * curl = nullptr; - struct curl_slist * chunk = nullptr; - - int download(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - curl = curl_easy_init(); - if (!curl) { - return 1; - } - - progress_data data; - File out; - if (!output_file.empty()) { - if (!out.open(output_file, "ab")) { - printe("Failed to open file for writing\n"); - - return 1; - } - - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - - return 1; - } - } - - set_write_options(response_str, out); - data.file_size = set_resume_point(output_file); - set_progress_options(progress, data); - set_headers(headers); - CURLcode res = perform(url); - if (res != CURLE_OK){ - printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res)); - return 1; - } - - return 0; - } - - void set_write_options(std::string * response_str, const File & out) { - if (response_str) { - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str); - } else { - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.file); - } - } - - size_t set_resume_point(const std::string & output_file) { - size_t file_size = 0; - if (std::filesystem::exists(output_file)) { - file_size = std::filesystem::file_size(output_file); - curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast(file_size)); - } - - return file_size; - } - - void set_progress_options(bool progress, progress_data & data) { - if (progress) { - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); - curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data); - curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress); - } - } - - void set_headers(const std::vector & headers) { - if (!headers.empty()) { - if (chunk) { - curl_slist_free_all(chunk); - chunk = 0; - } - - for (const auto & header : headers) { - chunk = curl_slist_append(chunk, header.c_str()); - } - - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk); - } - } - - CURLcode perform(const std::string & url) { - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https"); - curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); -#ifdef _WIN32 - curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - return curl_easy_perform(curl); - } - -#else // LLAMA_USE_CURL is not defined - -#define curl_off_t long long // temporary hack - - private: - // this is a direct translation of the cURL download() above - int download(const std::string & url, const std::vector & headers_vec, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - try { - auto [cli, url_parts] = common_http_client(url); - - httplib::Headers headers; - for (const auto & h : headers_vec) { - size_t pos = h.find(':'); - if (pos != std::string::npos) { - headers.emplace(h.substr(0, pos), h.substr(pos + 2)); - } - } - - File out; - if (!output_file.empty()) { - if (!out.open(output_file, "ab")) { - printe("Failed to open file for writing\n"); - return 1; - } - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - return 1; - } - } - - size_t resume_offset = 0; - if (!output_file.empty() && std::filesystem::exists(output_file)) { - resume_offset = std::filesystem::file_size(output_file); - if (resume_offset > 0) { - headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-"); - } - } - - progress_data data; - data.file_size = resume_offset; - - long long total_size = 0; - long long received_this_session = 0; - - auto response_handler = - [&](const httplib::Response & response) { - if (resume_offset > 0 && response.status != 206) { - printe("\nServer does not support resuming. Restarting download.\n"); - out.file = freopen(output_file.c_str(), "wb", out.file); - if (!out.file) { - return false; - } - data.file_size = 0; - } - if (progress) { - if (response.has_header("Content-Length")) { - total_size = std::stoll(response.get_header_value("Content-Length")); - } else if (response.has_header("Content-Range")) { - auto range = response.get_header_value("Content-Range"); - auto slash = range.find('/'); - if (slash != std::string::npos) { - total_size = std::stoll(range.substr(slash + 1)); - } - } - } - return true; - }; - - auto content_receiver = - [&](const char * chunk, size_t length) { - if (out.file && fwrite(chunk, 1, length, out.file) != length) { - return false; - } - if (response_str) { - response_str->append(chunk, length); - } - received_this_session += length; - - if (progress && total_size > 0) { - update_progress(&data, total_size, received_this_session, 0, 0); - } - return true; - }; - - auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver); - - if (data.printed) { - printe("\n"); - } - - if (!res) { - auto err = res.error(); - printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str()); - return 1; - } - - if (res->status >= 400) { - printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status); - return 1; - } - - } catch (const std::exception & e) { - printe("HTTP request failed: %s\n", e.what()); - return 1; - } - return 0; - } - -#endif // LLAMA_USE_CURL - - static std::string human_readable_time(double seconds) { - int hrs = static_cast(seconds) / 3600; - int mins = (static_cast(seconds) % 3600) / 60; - int secs = static_cast(seconds) % 60; - - if (hrs > 0) { - return string_format("%dh %02dm %02ds", hrs, mins, secs); - } else if (mins > 0) { - return string_format("%dm %02ds", mins, secs); - } else { - return string_format("%ds", secs); - } - } - - static std::string human_readable_size(curl_off_t size) { - static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" }; - char length = sizeof(suffix) / sizeof(suffix[0]); - int i = 0; - double dbl_size = size; - if (size > 1024) { - for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) { - dbl_size = size / 1024.0; - } - } - - return string_format("%.2f %s", dbl_size, suffix[i]); - } - - static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t, - curl_off_t) { - progress_data * data = static_cast(ptr); - if (total_to_download <= 0) { - return 0; - } - - total_to_download += data->file_size; - const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size; - const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download); - std::string progress_prefix = generate_progress_prefix(percentage); - - const double speed = calculate_speed(now_downloaded, data->start_time); - const double tim = (total_to_download - now_downloaded) / speed; - std::string progress_suffix = - generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim); - - int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix); - std::string progress_bar; - generate_progress_bar(progress_bar_width, percentage, progress_bar); - - print_progress(progress_prefix, progress_bar, progress_suffix); - data->printed = true; - - return 0; - } - - static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) { - return (now_downloaded_plus_file_size * 100) / total_to_download; - } - - static std::string generate_progress_prefix(curl_off_t percentage) { - return string_format("%3ld%% |", static_cast(percentage)); - } - - static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) { - const auto now = std::chrono::steady_clock::now(); - const std::chrono::duration elapsed_seconds = now - start_time; - return now_downloaded / elapsed_seconds.count(); - } - - static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download, - double speed, double estimated_time) { - const int width = 10; - return string_format("%*s/%*s%*s/s%*s", width, human_readable_size(now_downloaded_plus_file_size).c_str(), - width, human_readable_size(total_to_download).c_str(), width, - human_readable_size(speed).c_str(), width, human_readable_time(estimated_time).c_str()); - } - - static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) { - int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 3; - if (progress_bar_width < 1) { - progress_bar_width = 1; - } - - return progress_bar_width; - } - - static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage, - std::string & progress_bar) { - const curl_off_t pos = (percentage * progress_bar_width) / 100; - for (int i = 0; i < progress_bar_width; ++i) { - progress_bar.append((i < pos) ? "ā–ˆ" : " "); - } - - return progress_bar; - } - - static void print_progress(const std::string & progress_prefix, const std::string & progress_bar, - const std::string & progress_suffix) { - printe("\r" LOG_CLR_TO_EOL "%s%s| %s", progress_prefix.c_str(), progress_bar.c_str(), progress_suffix.c_str()); - } - // Function to write data to a file - static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) { - FILE * out = static_cast(stream); - return fwrite(ptr, size, nmemb, out); - } - - // Function to capture data into a string - static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) { - std::string * str = static_cast(stream); - str->append(static_cast(ptr), size * nmemb); - return size * nmemb; - } - -}; - -class LlamaData { - public: - llama_model_ptr model; - llama_sampler_ptr sampler; - llama_context_ptr context; - std::vector messages; // TODO: switch to common_chat_msg - std::list msg_strs; - std::vector fmtted; - - int init(Opt & opt) { - model = initialize_model(opt); - if (!model) { - return 1; - } - - context = initialize_context(model, opt); - if (!context) { - return 1; - } - - sampler = initialize_sampler(opt); - - return 0; - } - - private: - int download(const std::string & url, const std::string & output_file, const bool progress, - const std::vector & headers = {}, std::string * response_str = nullptr) { - HttpClient http; - if (http.init(url, headers, output_file, progress, response_str)) { - return 1; - } - - return 0; - } - - // Helper function to handle model tag extraction and URL construction - std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { - std::string model_tag = "latest"; - const size_t colon_pos = model.find(':'); - if (colon_pos != std::string::npos) { - model_tag = model.substr(colon_pos + 1); - model = model.substr(0, colon_pos); - } - - std::string url = base_url + model + "/manifests/" + model_tag; - - return { model, url }; - } - - // Helper function to download and parse the manifest - int download_and_parse_manifest(const std::string & url, const std::vector & headers, - nlohmann::json & manifest) { - std::string manifest_str; - int ret = download(url, "", false, headers, &manifest_str); - if (ret) { - return ret; - } - - manifest = nlohmann::json::parse(manifest_str); - - return 0; - } - - int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) { - // Find the second occurrence of '/' after protocol string - size_t pos = model.find('/'); - pos = model.find('/', pos + 1); - std::string hfr, hff; - std::vector headers = { "User-Agent: llama-cpp", "Accept: application/json" }; - std::string url; - - if (pos == std::string::npos) { - auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/"); - hfr = model_name; - - nlohmann::json manifest; - int ret = download_and_parse_manifest(manifest_url, headers, manifest); - if (ret) { - return ret; - } - - hff = manifest["ggufFile"]["rfilename"]; - } else { - hfr = model.substr(0, pos); - hff = model.substr(pos + 1); - } - - url = model_endpoint + hfr + "/resolve/main/" + hff; - - return download(url, bn, true, headers); - } - - int modelscope_dl(std::string & model, const std::string & bn) { - std::string model_endpoint = "https://modelscope.cn/models/"; - return dl_from_endpoint(model_endpoint, model, bn); - } - - int huggingface_dl(std::string & model, const std::string & bn) { - std::string model_endpoint = get_model_endpoint(); - return dl_from_endpoint(model_endpoint, model, bn); - } - - int ollama_dl(std::string & model, const std::string & bn) { - const std::vector headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" }; - if (model.find('/') == std::string::npos) { - model = "library/" + model; - } - - auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/"); - nlohmann::json manifest; - int ret = download_and_parse_manifest(manifest_url, {}, manifest); - if (ret) { - return ret; - } - - std::string layer; - for (const auto & l : manifest["layers"]) { - if (l["mediaType"] == "application/vnd.ollama.image.model") { - layer = l["digest"]; - break; - } - } - - std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer; - - return download(blob_url, bn, true, headers); - } - - int github_dl(const std::string & model, const std::string & bn) { - std::string repository = model; - std::string branch = "main"; - const size_t at_pos = model.find('@'); - if (at_pos != std::string::npos) { - repository = model.substr(0, at_pos); - branch = model.substr(at_pos + 1); - } - - const std::vector repo_parts = string_split(repository, "/"); - if (repo_parts.size() < 3) { - printe("Invalid GitHub repository format\n"); - return 1; - } - - const std::string & org = repo_parts[0]; - const std::string & project = repo_parts[1]; - std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch; - for (size_t i = 2; i < repo_parts.size(); ++i) { - url += "/" + repo_parts[i]; - } - - return download(url, bn, true); - } - - int s3_dl(const std::string & model, const std::string & bn) { - const size_t slash_pos = model.find('/'); - if (slash_pos == std::string::npos) { - return 1; - } - - const std::string bucket = model.substr(0, slash_pos); - const std::string key = model.substr(slash_pos + 1); - const char * access_key = std::getenv("AWS_ACCESS_KEY_ID"); - const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY"); - if (!access_key || !secret_key) { - printe("AWS credentials not found in environment\n"); - return 1; - } - - // Generate AWS Signature Version 4 headers - // (Implementation requires HMAC-SHA256 and date handling) - // Get current timestamp - const time_t now = time(nullptr); - const tm tm = *gmtime(&now); - const std::string date = strftime_fmt("%Y%m%d", tm); - const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm); - const std::vector headers = { - "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date + - "/us-east-1/s3/aws4_request", - "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime - }; - - const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key; - - return download(url, bn, true, headers); - } - - std::string basename(const std::string & path) { - const size_t pos = path.find_last_of("/\\"); - if (pos == std::string::npos) { - return path; - } - - return path.substr(pos + 1); - } - - int rm_until_substring(std::string & model_, const std::string & substring) { - const std::string::size_type pos = model_.find(substring); - if (pos == std::string::npos) { - return 1; - } - - model_ = model_.substr(pos + substring.size()); // Skip past the substring - return 0; - } - - int resolve_model(std::string & model_) { - int ret = 0; - if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) { - rm_until_substring(model_, "://"); - - return ret; - } - - const std::string bn = basename(model_); - if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") || - string_starts_with(model_, "hf.co/")) { - rm_until_substring(model_, "hf.co/"); - rm_until_substring(model_, "://"); - ret = huggingface_dl(model_, bn); - } else if (string_starts_with(model_, "ms://") || string_starts_with(model_, "modelscope://")) { - rm_until_substring(model_, "://"); - ret = modelscope_dl(model_, bn); - } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) && - !string_starts_with(model_, "https://ollama.com/library/")) { - ret = download(model_, bn, true); - } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) { - rm_until_substring(model_, "github:"); - rm_until_substring(model_, "://"); - ret = github_dl(model_, bn); - } else if (string_starts_with(model_, "s3://")) { - rm_until_substring(model_, "://"); - ret = s3_dl(model_, bn); - } else { // ollama:// or nothing - rm_until_substring(model_, "ollama.com/library/"); - rm_until_substring(model_, "://"); - ret = ollama_dl(model_, bn); - } - - model_ = bn; - - return ret; - } - - // Initializes the model and returns a unique pointer to it - llama_model_ptr initialize_model(Opt & opt) { - ggml_backend_load_all(); - resolve_model(opt.model_); - printe("\r" LOG_CLR_TO_EOL "Loading model"); - llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params)); - if (!model) { - printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); - } - - printe("\r" LOG_CLR_TO_EOL); - return model; - } - - // Initializes the context with the specified parameters - llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { - llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params)); - if (!context) { - printe("%s: error: failed to create the llama_context\n", __func__); - } - - return context; - } - - // Initializes and configures the sampler - llama_sampler_ptr initialize_sampler(const Opt & opt) { - llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1)); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature)); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); - - return sampler; - } -}; - -// Add a message to `messages` and store its content in `msg_strs` -static void add_message(const char * role, const std::string & text, LlamaData & llama_data) { - llama_data.msg_strs.push_back(std::move(text)); - llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() }); -} - -// Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) { - common_chat_templates_inputs inputs; - for (const auto & msg : llama_data.messages) { - common_chat_msg cmsg; - cmsg.role = msg.role; - cmsg.content = msg.content; - inputs.messages.push_back(cmsg); - } - inputs.add_generation_prompt = append; - inputs.use_jinja = use_jinja; - - auto chat_params = common_chat_templates_apply(tmpls, inputs); - // TODO: use other params for tool calls. - auto result = chat_params.prompt; - llama_data.fmtted.resize(result.size() + 1); - memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); - return result.size(); -} - -// Function to tokenize the prompt -static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, - std::vector & prompt_tokens, const LlamaData & llama_data) { - const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1; - int n_tokens = prompt.size() + 2 * is_first; - prompt_tokens.resize(n_tokens); - n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(), - prompt_tokens.data(), prompt_tokens.size(), - is_first, /*parse_special =*/true); - if (n_tokens == std::numeric_limits::min()) { - printe("tokenization failed: input too large\n"); - return -1; - } - if (n_tokens < 0) { - prompt_tokens.resize(-n_tokens); - int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(), - prompt_tokens.data(), prompt_tokens.size(), - is_first, /*parse_special =*/true); - if (check != -n_tokens) { - printe("failed to tokenize the prompt (size mismatch)\n"); - return -1; - } - n_tokens = check; - } else { - prompt_tokens.resize(n_tokens); - } - return n_tokens; -} - -// Check if we have enough space in the context to evaluate this batch -static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { - const int n_ctx = llama_n_ctx(ctx.get()); - const int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx.get()), 0); - if (n_ctx_used + batch.n_tokens > n_ctx) { - printf(LOG_COL_DEFAULT "\n"); - printe("context size exceeded\n"); - return 1; - } - - return 0; -} - -// convert the token to a string -static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) { - char buf[256]; - int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true); - if (n < 0) { - printe("failed to convert token to piece\n"); - return 1; - } - - piece = std::string(buf, n); - return 0; -} - -static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) { - printf("%s", piece.c_str()); - fflush(stdout); - response += piece; -} - -// helper function to evaluate a prompt and generate a response -static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { - const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); - - std::vector tokens; - if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) { - return 1; - } - - // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); - llama_token new_token_id; - while (true) { - check_context_size(llama_data.context, batch); - if (llama_decode(llama_data.context.get(), batch)) { - printe("failed to decode\n"); - return 1; - } - - // sample the next token, check is it an end of generation? - new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); - if (llama_vocab_is_eog(vocab, new_token_id)) { - break; - } - - std::string piece; - if (convert_token_to_string(vocab, new_token_id, piece)) { - return 1; - } - - print_word_and_concatenate_to_response(piece, response); - - // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); - } - - printf(LOG_COL_DEFAULT); - return 0; -} - -static int read_user_input(std::string & user_input) { - static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX"); - static const char * prompt_prefix = prompt_prefix_env ? prompt_prefix_env : "> "; -#ifdef WIN32 - printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix); - - std::getline(std::cin, user_input); - if (std::cin.eof()) { - printf("\n"); - return 1; - } -#else - std::unique_ptr line(const_cast(linenoise(prompt_prefix)), free); - if (!line) { - return 1; - } - - user_input = line.get(); -#endif - - if (user_input == "/bye") { - return 1; - } - - if (user_input.empty()) { - return 2; - } - -#ifndef WIN32 - linenoiseHistoryAdd(line.get()); -#endif - - return 0; // Should have data in happy path -} - -// Function to generate a response based on the prompt -static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response, - const bool stdout_a_terminal) { - // Set response color - if (stdout_a_terminal) { - printf(LOG_COL_YELLOW); - } - - if (generate(llama_data, prompt, response)) { - printe("failed to generate response\n"); - return 1; - } - - // End response with color reset and newline - printf("\n%s", stdout_a_terminal ? LOG_COL_DEFAULT : ""); - return 0; -} - -// Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { - const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja); - if (new_len < 0) { - printe("failed to apply the chat template\n"); - return -1; - } - - output_length = new_len; - return 0; -} - -// Helper function to handle user input -static int handle_user_input(std::string & user_input, const std::string & user) { - if (!user.empty()) { - user_input = user; - return 0; // No need for interactive input - } - - return read_user_input(user_input); // Returns true if input ends the loop -} - -static bool is_stdin_a_terminal() { -#if defined(_WIN32) - HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); - DWORD mode; - return GetConsoleMode(hStdin, &mode); -#else - return isatty(STDIN_FILENO); -#endif -} - -static bool is_stdout_a_terminal() { -#if defined(_WIN32) - HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE); - DWORD mode; - return GetConsoleMode(hStdout, &mode); -#else - return isatty(STDOUT_FILENO); -#endif -} - -// Function to handle user input -static int get_user_input(std::string & user_input, const std::string & user) { - while (true) { - const int ret = handle_user_input(user_input, user); - if (ret == 1) { - return 1; - } - - if (ret == 2) { - continue; - } - - break; - } - - return 0; -} - -// Reads a chat template file to be used -static std::string read_chat_template_file(const std::string & chat_template_file) { - File file; - if (!file.open(chat_template_file, "r")) { - printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno)); - return ""; - } - - return file.to_string(); -} - -static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data, - const common_chat_templates_ptr & chat_templates, int & prev_len, - const bool stdout_a_terminal) { - add_message("user", opt.user.empty() ? user_input : opt.user, llama_data); - int new_len; - if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) { - return 1; - } - - std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); - std::string response; - if (generate_response(llama_data, prompt, response, stdout_a_terminal)) { - return 1; - } - - if (!opt.user.empty()) { - return 2; - } - - add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) { - return 1; - } - - return 0; -} - -// Main chat loop function -static int chat_loop(LlamaData & llama_data, const Opt & opt) { - int prev_len = 0; - llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - std::string chat_template; - if (!opt.chat_template_file.empty()) { - chat_template = read_chat_template_file(opt.chat_template_file); - } - - common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template); - static const bool stdout_a_terminal = is_stdout_a_terminal(); - while (true) { - // Get user input - std::string user_input; - if (get_user_input(user_input, opt.user) == 1) { - return 0; - } - - const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal); - if (ret == 1) { - return 1; - } else if (ret == 2) { - break; - } - } - - return 0; -} - -static void log_callback(const enum ggml_log_level level, const char * text, void * p) { - const Opt * opt = static_cast(p); - if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) { - printe("%s", text); - } -} - -static std::string read_pipe_data() { - std::ostringstream result; - result << std::cin.rdbuf(); // Read all data from std::cin - return result.str(); -} - -static void ctrl_c_handling() { -#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = sigint_handler; - sigemptyset(&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); -#elif defined(_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif -} - -int main(int argc, const char ** argv) { - ctrl_c_handling(); - Opt opt; - const int ret = opt.init(argc, argv); - if (ret == 2) { - return 0; - } else if (ret) { - return 1; - } - - if (!is_stdin_a_terminal()) { - if (!opt.user.empty()) { - opt.user += "\n\n"; - } - - opt.user += read_pipe_data(); - } - - llama_log_set(log_callback, &opt); - LlamaData llama_data; - if (llama_data.init(opt)) { - return 1; - } - - if (chat_loop(llama_data, opt)) { - return 1; - } - - return 0; -} From ae9f8df77882716b1702df2bed8919499e64cc28 Mon Sep 17 00:00:00 2001 From: R Date: Wed, 7 Jan 2026 16:57:42 +0100 Subject: [PATCH 09/27] fix(docker): add missing libglvnd libraries to Vulkan image (#18664) Add libglvnd0, libgl1, libglx0, libegl1, libgles2 to the Vulkan Dockerfile base image. These libraries are required by mesa-vulkan-drivers to properly initialize the Vulkan ICD and detect GPU devices. Without these libraries, vkEnumeratePhysicalDevices() returns an empty list, resulting in "ggml_vulkan: No devices found." error. Fixes #17761 --- .devops/vulkan.Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index b37b4f277d..89831ed5c2 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -33,6 +33,7 @@ FROM ubuntu:$UBUNTU_VERSION AS base RUN apt-get update \ && apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \ + libglvnd0 libgl1 libglx0 libegl1 libgles2 \ && apt autoremove -y \ && apt clean -y \ && rm -rf /tmp/* /var/tmp/* \ From f5245b5e4eff4f0d5624dca39b0bd612da7111ff Mon Sep 17 00:00:00 2001 From: Oliver Walsh Date: Wed, 7 Jan 2026 21:32:44 +0000 Subject: [PATCH 10/27] cuda : fix build on cuda 12.8 (#18672) compute121 requires 12.9 Signed-off-by: Oliver Walsh --- ggml/src/ggml-cuda/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index dcc004134d..d313c1ac9a 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -47,7 +47,10 @@ if (CUDAToolkit_FOUND) # check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead. # However, the architectures 120a-real and 121a-real should work with basically any CMake version and # until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell. - list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real 121a-real) + list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real) + endif() + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9") + list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real) endif() endif() endif() From 7e16fef085e8727d534b21a148aa70ebe94f23a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Wed, 7 Jan 2026 22:34:51 +0100 Subject: [PATCH 11/27] convert : more variants of rope_theta config entries (#18668) --- convert_hf_to_gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0a8bac0e2d..386e2a7e52 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -771,8 +771,8 @@ class TextModel(ModelBase): self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {} - rope_theta = self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True) - local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "swa_rope_theta", "rope_local_base_freq"], optional=True) + rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True) + local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True) # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: From 5b8844ae531d8ff09c1c00a2022293d5b674c787 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Wed, 7 Jan 2026 22:35:34 +0100 Subject: [PATCH 12/27] scripts : fix repos cloned with .git extension (#18669) --- scripts/pr2wt.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh index 22251339ac..36ccde2f34 100755 --- a/scripts/pr2wt.sh +++ b/scripts/pr2wt.sh @@ -34,6 +34,7 @@ url_origin=$(git config --get remote.origin.url) || { } org_repo=$(echo $url_origin | cut -d/ -f4-) +org_repo=${org_repo%.git} echo "org/repo: $org_repo" From 568371a7264c30ad4583f1859cb815dfc0bc14fa Mon Sep 17 00:00:00 2001 From: shaofeiqi <109865877+shaofeiqi@users.noreply.github.com> Date: Wed, 7 Jan 2026 22:04:50 -0800 Subject: [PATCH 13/27] opencl: add FILL op support (#18682) --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 57 ++++++++++++++++++++++++++++ ggml/src/ggml-opencl/kernels/fill.cl | 17 +++++++++ 3 files changed, 75 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/fill.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 2a4b79eb6a..f666f08098 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS add add_id argsort + fill clamp cpy cvt diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 353f6a4b46..472e2df50a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -489,6 +489,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; + cl_kernel kernel_fill; cl_kernel kernel_clamp; cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick, kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; @@ -787,6 +788,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // fill + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "fill.cl.h" + }; +#else + const std::string kernel_src = read_file("fill.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_fill = clCreateKernel(prog, "kernel_fill_f32", &err), err)); + GGML_LOG_CONT("."); + + CL_CHECK(clReleaseProgram(prog)); + } + // clamp { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3104,6 +3123,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_OP_FILL: + return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: @@ -5860,6 +5881,36 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } +static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src0); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float v = 0.0f; + memcpy(&v, ((int32_t *) dst->op_params), sizeof(float)); + + const int64_t n = ggml_nelements(dst); + + cl_kernel kernel = backend_ctx->kernel_fill; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(float), &v)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(float), &n)); + + size_t local_work_size[1] = { 256 }; + size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); +} + static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9595,6 +9646,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_glu; break; + case GGML_OP_FILL: + if (!any_on_device) { + return false; + } + func = ggml_cl_fill; + break; case GGML_OP_CLAMP: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/fill.cl b/ggml/src/ggml-opencl/kernels/fill.cl new file mode 100644 index 0000000000..9b73938d93 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/fill.cl @@ -0,0 +1,17 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// fill +//------------------------------------------------------------------------------ +__kernel void kernel_fill_f32( + __global float *dst, + ulong offsetd, + float v, + int n + +) { + dst = (global float*)((global char*)dst + offsetd); + if(get_global_id(0) < n){ + dst[get_global_id(0)] = v; + } +} From 2038101bd9b1dcf45b5410b969fbc5206e25d993 Mon Sep 17 00:00:00 2001 From: Julius Tischbein Date: Thu, 8 Jan 2026 07:35:30 +0100 Subject: [PATCH 14/27] llama : add `use_direct_io` flag for model loading (#18166) * Adding --direct-io flag for model loading * Fixing read_raw() calls * Fixing Windows read_raw_at * Changing type off_t to size_t for windows and Renaming functions * disable direct io when mmap is explicitly enabled * Use read_raw_unsafe when upload_backend is available, not functional on some devices with Vulkan and SYCL * Fallback to std::fread in case O_DIRECT fails due to bad address * Windows: remove const keywords and unused functions * Update src/llama-mmap.cpp Co-authored-by: Georgi Gerganov --------- Co-authored-by: jtischbein Co-authored-by: Georgi Gerganov --- common/arg.cpp | 13 +++- common/common.cpp | 1 + common/common.h | 3 +- examples/diffusion/diffusion-cli.cpp | 1 + include/llama.h | 1 + src/llama-mmap.cpp | 111 +++++++++++++++++---------- src/llama-mmap.h | 9 ++- src/llama-model-loader.cpp | 22 ++++-- src/llama-model-loader.h | 2 + src/llama-model.cpp | 4 +- src/llama-quant.cpp | 2 +- src/llama.cpp | 2 +- 12 files changed, 118 insertions(+), 53 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index e7966d9d5c..26c790c7e0 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2088,11 +2088,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex add_opt(common_arg( {"--mmap"}, {"--no-mmap"}, - string_format("whether to memory-map model (if disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"), + string_format("whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"), [](common_params & params, bool value) { params.use_mmap = value; + if (value) { + params.use_direct_io = false; // disable direct io when mmap is explicitly enabled + } } ).set_env("LLAMA_ARG_MMAP")); + add_opt(common_arg( + {"-dio", "--direct-io"}, + {"-ndio", "--no-direct-io"}, + string_format("use DirectIO if available. Takes precedence over --mmap (default: %s)", params.use_direct_io ? "enabled" : "disabled"), + [](common_params & params, bool value) { + params.use_direct_io = value; + } + ).set_env("LLAMA_ARG_DIO")); add_opt(common_arg( {"--numa"}, "TYPE", "attempt optimizations that help on some NUMA systems\n" diff --git a/common/common.cpp b/common/common.cpp index 41b2b6833e..34fa3b5a42 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1366,6 +1366,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; + mparams.use_direct_io = params.use_direct_io; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.use_extra_bufts = !params.no_extra_bufts; diff --git a/common/common.h b/common/common.h index d6fd0d37a9..d55a6b71fb 100644 --- a/common/common.h +++ b/common/common.h @@ -428,7 +428,8 @@ struct common_params { bool kv_unified = false; // enable unified KV cache bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool use_mmap = true; // use mmap for faster loads + bool use_mmap = true; // enable mmap to use filesystem cache + bool use_direct_io = true; // read from disk without buffering for faster model loading bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp index 273942a165..d50f754092 100644 --- a/examples/diffusion/diffusion-cli.cpp +++ b/examples/diffusion/diffusion-cli.cpp @@ -553,6 +553,7 @@ int main(int argc, char ** argv) { model_params.n_gpu_layers = params.n_gpu_layers; model_params.devices = params.devices.data(); model_params.use_mmap = params.use_mmap; + model_params.use_direct_io = params.use_direct_io; model_params.use_mlock = params.use_mlock; model_params.check_tensors = params.check_tensors; diff --git a/include/llama.h b/include/llama.h index 05cb653254..edc4c871a1 100644 --- a/include/llama.h +++ b/include/llama.h @@ -309,6 +309,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible + bool use_direct_io; // use direct io, takes precedence over use_mmap bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 232005e140..2da857b3aa 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -110,7 +110,7 @@ struct llama_file::impl { } } - void read_raw(void * ptr, size_t len) const { + void read_raw(void * ptr, size_t len) { size_t bytes_read = 0; while (bytes_read < len) { size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); @@ -127,7 +127,7 @@ struct llama_file::impl { } } - uint32_t read_u32() const { + uint32_t read_u32() { uint32_t val; read_raw(&val, sizeof(val)); return val; @@ -154,8 +154,8 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } - void read_aligned_chunk(size_t offset, void * dest, size_t size) const { - throw std::runtime_error("DirectIO is not implemented on Windows."); + bool has_direct_io() const { + return true; } ~impl() { @@ -164,33 +164,45 @@ struct llama_file::impl { } } #else - impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) { + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) : fname(fname) { #ifdef __linux__ // Try unbuffered I/O for read only if (use_direct_io && std::strcmp(mode, "rb") == 0) { - fd = open(fname, O_RDONLY | O_DIRECT); - - if (fd != -1) { - struct stat file_stats{}; - fstat(fd, &file_stats); - - size = file_stats.st_size; - alignment = file_stats.st_blksize; - - off_t ret = lseek(fd, 0, SEEK_SET); - if (ret == -1) { - throw std::runtime_error(format("seek error: %s", strerror(errno))); - } + if (init_fd()) { return; } - - LLAMA_LOG_WARN("Failed to open model %s with error: %s. Falling back to buffered I/O", - fname, strerror(errno)); + LLAMA_LOG_WARN("Failed to open file '%s' with error: %s. Falling back to buffered I/O", + fname, strerror(errno)); } #endif - fp = ggml_fopen(fname, mode); + init_fp(mode); + } + +#ifdef __linux__ + bool init_fd() { + fd = open(fname.c_str(), O_RDONLY | O_DIRECT); + + if (fd != -1) { + struct stat file_stats{}; + fstat(fd, &file_stats); + + size = file_stats.st_size; + alignment = file_stats.st_blksize; + + off_t ret = lseek(fd, 0, SEEK_SET); + if (ret == -1) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + return true; + } + return false; + } +#endif + + void init_fp(const char * mode) { + fp = ggml_fopen(fname.c_str(), mode); if (fp == NULL) { - throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + throw std::runtime_error(format("failed to open %s: %s", fname.c_str(), strerror(errno))); } seek(0, SEEK_END); size = tell(); @@ -226,7 +238,7 @@ struct llama_file::impl { } } - void read_raw(void * ptr, size_t len) const { + void read_raw_unsafe(void * ptr, size_t len) { if (len == 0) { return; } @@ -249,6 +261,17 @@ struct llama_file::impl { if (errno == EINTR) { continue; // Interrupted by signal, retry } + // Fallback to std::fread in case the DMA controller cannot access the buffer + if (errno == EFAULT) { + auto curr_off = tell(); + close(fd); + fd = -1; + alignment = 1; + init_fp("rb"); + seek(curr_off, SEEK_SET); + read_raw_unsafe(ptr, len); + return; + } throw std::runtime_error(format("read error: %s", strerror(errno))); } if (ret == 0) { @@ -266,7 +289,8 @@ struct llama_file::impl { } } - void read_aligned_chunk(size_t offset, void * dest, size_t size) const { + void read_aligned_chunk(void * dest, size_t size) { + size_t offset = tell(); off_t aligned_offset = offset & ~(alignment - 1); off_t offset_from_alignment = offset - aligned_offset; size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1); @@ -283,13 +307,21 @@ struct llama_file::impl { std::unique_ptr buffer(raw_buffer); seek(aligned_offset, SEEK_SET); - read_raw(buffer.get(), bytes_to_read); + read_raw_unsafe(buffer.get(), bytes_to_read); uintptr_t actual_data = reinterpret_cast(buffer.get()) + offset_from_alignment; memcpy(dest, reinterpret_cast(actual_data), size); } - uint32_t read_u32() const { + void read_raw(void * ptr, size_t len) { + if (has_direct_io()) { + read_aligned_chunk(ptr, len); + } else { + read_raw_unsafe(ptr, len); + } + } + + uint32_t read_u32() { uint32_t ret; read_raw(&ret, sizeof(ret)); return ret; @@ -310,6 +342,10 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } + bool has_direct_io() const { + return fd != -1 && alignment > 1; + } + ~impl() { if (fd != -1) { close(fd); @@ -318,17 +354,9 @@ struct llama_file::impl { } } int fd = -1; + std::string fname; #endif - void read_raw_at(void * ptr, size_t len, size_t offset) const { - if (alignment != 1) { - read_aligned_chunk(offset, ptr, len); - } else { - seek(offset, SEEK_SET); - read_raw(ptr, len); - } - } - size_t read_alignment() const { return alignment; } @@ -347,6 +375,7 @@ size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::size() const { return pimpl->size; } size_t llama_file::read_alignment() const { return pimpl->read_alignment(); } +bool llama_file::has_direct_io() const { return pimpl->has_direct_io(); } int llama_file::file_id() const { #ifdef _WIN32 @@ -361,10 +390,14 @@ int llama_file::file_id() const { } void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } -void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } -void llama_file::read_raw_at(void * ptr, size_t len, size_t offset) const { pimpl->read_raw_at(ptr, len, offset); } +void llama_file::read_raw(void * ptr, size_t len) { pimpl->read_raw(ptr, len); } +#ifdef _WIN32 +void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw(ptr, len); } +#else +void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw_unsafe(ptr, len); } +#endif -uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } +uint32_t llama_file::read_u32() { return pimpl->read_u32(); } void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); } void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); } diff --git a/src/llama-mmap.h b/src/llama-mmap.h index 729aac164b..29ce4d2468 100644 --- a/src/llama-mmap.h +++ b/src/llama-mmap.h @@ -24,15 +24,16 @@ struct llama_file { void seek(size_t offset, int whence) const; - void read_raw(void * ptr, size_t len) const; - void read_raw_at(void * ptr, size_t len, size_t offset) const; - void read_aligned_chunk(size_t offset, void * dest, size_t size) const; - uint32_t read_u32() const; + void read_raw(void * ptr, size_t len); + void read_raw_unsafe(void * ptr, size_t len); + void read_aligned_chunk(void * dest, size_t size); + uint32_t read_u32(); void write_raw(const void * ptr, size_t len) const; void write_u32(uint32_t val) const; size_t read_alignment() const; + bool has_direct_io() const; private: struct impl; std::unique_ptr pimpl; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 5003b4fbf5..e66febaa02 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -495,6 +495,7 @@ llama_model_loader::llama_model_loader( const std::string & fname, std::vector & splits, bool use_mmap, + bool use_direct_io, bool check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, @@ -527,9 +528,17 @@ llama_model_loader::llama_model_loader( get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - files.emplace_back(new llama_file(fname.c_str(), "rb", !use_mmap)); + files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); contexts.emplace_back(ctx); + use_direct_io = use_direct_io && files.back()->has_direct_io(); + + // Disable mmap in case Direct I/O is enabled and available + if (use_direct_io && use_mmap) { + use_mmap = false; + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + } + // Save tensors data offset of the main file. // For subsidiary files, `meta` tensor data offset must not be used, // so we build a unified tensors index for weights. @@ -595,7 +604,7 @@ llama_model_loader::llama_model_loader( } } - files.emplace_back(new llama_file(fname_split, "rb", !use_mmap)); + files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); contexts.emplace_back(ctx); // Save tensors data offset info of the shard. @@ -739,6 +748,7 @@ llama_model_loader::llama_model_loader( } this->use_mmap = use_mmap; + this->use_direct_io = use_direct_io; this->check_tensors = check_tensors; this->no_alloc = no_alloc; } @@ -1100,7 +1110,8 @@ bool llama_model_loader::load_all_data( const auto & file = files.at(weight->idx); if (ggml_backend_buffer_is_host(cur->buffer)) { - file->read_raw_at(cur->data, n_size, weight->offs); + file->seek(weight->offs, SEEK_SET); + file->read_raw(cur->data, n_size); if (check_tensors) { validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); @@ -1132,7 +1143,7 @@ bool llama_model_loader::load_all_data( ggml_backend_event_synchronize(events[buffer_idx]); // Read aligned chunk from file - file->read_raw(reinterpret_cast(ptr_dest_aligned), read_size); + file->read_raw_unsafe(reinterpret_cast(ptr_dest_aligned), read_size); // Calculate actual data portion (excluding alignment padding) uintptr_t ptr_data = ptr_dest_aligned; @@ -1162,7 +1173,8 @@ bool llama_model_loader::load_all_data( } } else { read_buf.resize(n_size); - file->read_raw_at(read_buf.data(), n_size, weight->offs); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index d13299ad3f..65953dd3d5 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -70,6 +70,7 @@ struct llama_model_loader { size_t n_bytes = 0; bool use_mmap = false; + bool use_direct_io = false; bool check_tensors; bool no_alloc; @@ -97,6 +98,7 @@ struct llama_model_loader { const std::string & fname, std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, + bool use_direct_io, bool check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 04c48b5fd3..7ac59846bb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2440,7 +2440,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const bool use_mmap_buffer = true; - LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", + __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); @@ -7973,6 +7974,7 @@ llama_model_params llama_model_default_params() { /*.kv_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, + /*.use_direct_io =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index bc4b05c3b5..048d65a75c 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -596,7 +596,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); diff --git a/src/llama.cpp b/src/llama.cpp index 0162ae8d58..dfefb3d2b5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -794,7 +794,7 @@ static int llama_model_load(const std::string & fname, std::vector model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); + llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); From df7fb92170f1c6ed08bf0943d6d8bf1191543a95 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 8 Jan 2026 09:29:15 +0100 Subject: [PATCH 15/27] model-conversion : remove -st targets for converted model (#18689) This commit removes the '-st` make target for running the converted embedding model. The motivation for this is that the pooling type is now part of the .gguf metdata of the model and this is used by llama-debug when running the model. So there is no need to specify the pooling type separately any more. The commit also adds an option to specify the type of normalization applied to the output embeddings when running the converted model. And the readme documentation has been updated to reflect these changes. --- examples/model-conversion/Makefile | 7 ++----- examples/model-conversion/README.md | 18 ++++++++++++++---- .../scripts/embedding/run-converted-model.sh | 14 +++++--------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index f8dc525a77..359b9cfd8e 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -138,16 +138,13 @@ embedding-run-original-model-st: embedding-run-original-model embedding-run-converted-model: @./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \ - $(if $(USE_POOLING),--pooling) - -embedding-run-converted-model-st: USE_POOLING=1 -embedding-run-converted-model-st: embedding-run-converted-model + $(if $(EMBD_NORMALIZE),--embd-normalize "$(EMBD_NORMALIZE)") embedding-verify-logits: embedding-run-original-model embedding-run-converted-model @./scripts/embedding/compare-embeddings-logits.sh \ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") -embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st +embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model @./scripts/embedding/compare-embeddings-logits.sh \ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md index 8163b306b4..637870a5c1 100644 --- a/examples/model-conversion/README.md +++ b/examples/model-conversion/README.md @@ -198,14 +198,13 @@ model, and the other is a text file which allows for manual visual inspection. #### Using SentenceTransformer with numbered layers For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense, -03_Dense, 04_Normalize), use the `-st` targets to apply all these layers: +03_Dense, 04_Normalize), these will be applied automatically when running the +converted model but currently there is a separate target to run the original +version: ```console # Run original model with SentenceTransformer (applies all numbered layers) (venv) $ make embedding-run-original-model-st - -# Run converted model with pooling enabled -(venv) $ make embedding-run-converted-model-st ``` This will use the SentenceTransformer library to load and run the model, which @@ -213,6 +212,17 @@ automatically applies all the numbered layers in the correct order. This is particularly useful when comparing with models that should include these additional transformation layers beyond just the base model output. +The type of normalization can be specified for the converted model but is not +strictly necessary as the verification uses cosine similarity and the magnitude +of the output vectors does not affect this. But the normalization type can be +specified as an argument to the target which might be useful for manual +inspection: +```console +(venv) $ make embedding-verify-logits-st EMBD_NORMALIZE=1 +``` +The original model will apply the normalization according to the normalization +layer specified in the modules.json configuration file. + ### Model conversion After updates have been made to [gguf-py](../../gguf-py) to add support for the new model the model can be converted to GGUF format using the following command: diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index 5d264b0663..84625cec3d 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -5,7 +5,7 @@ set -e # Parse command line arguments CONVERTED_MODEL="" PROMPTS_FILE="" -USE_POOLING="" +EMBD_NORMALIZE="2" while [[ $# -gt 0 ]]; do case $1 in @@ -13,9 +13,9 @@ while [[ $# -gt 0 ]]; do PROMPTS_FILE="$2" shift 2 ;; - --pooling) - USE_POOLING="1" - shift + --embd-normalize) + EMBD_NORMALIZE="$2" + shift 2 ;; *) if [ -z "$CONVERTED_MODEL" ]; then @@ -51,8 +51,4 @@ fi echo $CONVERTED_MODEL cmake --build ../../build --target llama-debug -j8 -if [ -n "$USE_POOLING" ]; then - ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling mean -p "$PROMPT" --save-logits -else - ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling none -p "$PROMPT" --save-logits -fi +../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE From 9c142e3a2a8f1c7415511bd9d24f4790ce2dac88 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 8 Jan 2026 09:29:53 +0100 Subject: [PATCH 16/27] model-conversion : add warn about transformers mismatch (#18691) This commit adds a check comparing the installed transformers library with the transformers version that the original model supports. This check will be performed upon a model verification failure and prints a warning/hint to the user suggesting to install the correct version of the transformers library. The motivation for this change is that it is possible for the model verification to fail due to differences in the transformers library used and it might not be obvious that this could be the cause of the failure. With this warning the correct version can be checked and hopefully save time troubleshooting the cause of the verification failure. --- examples/model-conversion/Makefile | 2 +- .../scripts/causal/compare-logits.py | 10 ++-- .../model-conversion/scripts/utils/common.py | 54 +++++++++++++++++++ .../scripts/utils/semantic_check.py | 7 ++- 4 files changed, 63 insertions(+), 10 deletions(-) diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index 359b9cfd8e..3b0505911d 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -61,7 +61,7 @@ causal-run-converted-model: @CONVERTED_MODEL="$(CONVERTED_MODEL)" ./scripts/causal/run-converted-model.sh causal-verify-logits: causal-run-original-model causal-run-converted-model - @./scripts/causal/compare-logits.py + @MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/compare-logits.py @MODEL_PATH="$(MODEL_PATH)" ./scripts/utils/check-nmse.py -m ${MODEL_PATH} causal-run-original-embeddings: diff --git a/examples/model-conversion/scripts/causal/compare-logits.py b/examples/model-conversion/scripts/causal/compare-logits.py index 1a933207d5..83bd14c659 100755 --- a/examples/model-conversion/scripts/causal/compare-logits.py +++ b/examples/model-conversion/scripts/causal/compare-logits.py @@ -3,10 +3,11 @@ import sys import numpy as np from pathlib import Path +import os # Add utils directory to path for direct script execution sys.path.insert(0, str(Path(__file__).parent.parent / "utils")) -from common import get_model_name_from_env_path, compare_tokens # type: ignore[import-not-found] +from common import get_model_name_from_env_path, compare_tokens, exit_with_warning # type: ignore[import-not-found] def quick_logits_check(pytorch_file, llamacpp_file): """Lightweight sanity check before NMSE""" @@ -38,6 +39,7 @@ def quick_logits_check(pytorch_file, llamacpp_file): return True def main(): + model_path = os.environ.get('MODEL_PATH') model_name = get_model_name_from_env_path('MODEL_PATH') data_dir = Path("data") pytorch_file = data_dir / f"pytorch-{model_name}.bin" @@ -62,8 +64,7 @@ def main(): print("šŸ” Token Comparison Check") print("=" * 40) if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"): - print("\nāŒ Token mismatch detected") - sys.exit(1) + exit_with_warning("\nāŒ Token mismatch detected", model_path) print() print("šŸ” GGML Model Validation for model ", model_name) @@ -80,8 +81,7 @@ def main(): print(" Ok to proceed with NMSE check...") sys.exit(0) else: - print(f"āŒ NOK: Top 10 predictions don't match - generation will differ") - sys.exit(1) + exit_with_warning(f"āŒ NOK: Top 10 predictions don't match - generation will differ", model_path) if __name__ == "__main__": main() diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py index 71761127bb..aa4bab2601 100644 --- a/examples/model-conversion/scripts/utils/common.py +++ b/examples/model-conversion/scripts/utils/common.py @@ -3,6 +3,9 @@ import os import sys import torch +import transformers +import json +import textwrap import numpy as np from pathlib import Path @@ -243,3 +246,54 @@ def compare_tokens(original, converted, type_suffix="", output_dir="data"): print(f" ... and {len(mismatches) - num_to_show} more mismatches") return False + + +def show_version_warning(current_version, model_version): + if not model_version: + return False + + try: + from packaging.version import parse, InvalidVersion + try: + return parse(current_version) < parse(model_version) + except InvalidVersion: + return current_version != model_version + except ImportError: + return current_version != model_version + +def get_model_transformers_version(model_path): + if not model_path: + return None + + config_path = Path(model_path) / "config.json" + if not config_path.is_file(): + return None + + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + return config.get("transformers_version") + except (IOError, json.JSONDecodeError) as e: + print(f"Warning: Could not read or parse {config_path}: {e}", file=sys.stderr) + return None + +def exit_with_warning(message, model_path): + print(message) + + if model_path and transformers is not None: + model_transformers_version = get_model_transformers_version(model_path) + transformers_version = transformers.__version__ + if show_version_warning(transformers_version, model_transformers_version): + warning_message = f""" + ===================================================================== + Verification failure might be due to a transformers version mismatch: + + Current transformers version: {transformers_version} + Model's required version : {model_transformers_version} + + Consider installing the version specified by the model's config: + pip install transformers=={model_transformers_version} + ===================================================================== + """ + print(textwrap.dedent(warning_message)) + sys.exit(1) diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py index 38b03ce4d2..73e20ea489 100644 --- a/examples/model-conversion/scripts/utils/semantic_check.py +++ b/examples/model-conversion/scripts/utils/semantic_check.py @@ -7,7 +7,7 @@ import importlib from pathlib import Path from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel -from common import compare_tokens # type: ignore[import-not-found] +from common import compare_tokens, exit_with_warning # type: ignore[import-not-found] unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') @@ -174,8 +174,7 @@ def main(): print("=" * 70) data_dir = python_emb_path.parent if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)): - print("\nāŒ Token mismatch detected") - exit(1) + exit_with_warning("\nāŒ Token mismatch detected", args.model_path) print() # Single prompt detailed comparison @@ -237,7 +236,7 @@ def main(): elif avg_cross_sim > 0.70: print("āš ļø FAIR: Models have some differences") else: - print("āŒ POOR: Models are significantly different") + exit_with_warning("āŒ POOR: Models are significantly different", args.model_path) if __name__ == "__main__": main() From 9a5724dee2457d58e506268efcb1d2286498cf3d Mon Sep 17 00:00:00 2001 From: Doctor Shotgun <126566557+DocShotgun@users.noreply.github.com> Date: Thu, 8 Jan 2026 01:03:21 -0800 Subject: [PATCH 17/27] ggml: add env var GGML_OP_OFFLOAD_MIN_BATCH (#18535) * ggml: add env var GGML_OP_OFFLOAD_MIN_BATCH * makes the min_batch_size for triggering op offload configurable via env var, defaulting to the prior hardcoded value of 32 * ggml: read GGML_OP_OFFLOAD_MIN_BATCH once and store to dev ctx * cann: forward declaration of device context struct * cann: move offload op check after device context declaration * cuda: fix whitespace Co-authored-by: Aman Gupta --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cann/ggml-cann.cpp | 44 +++++++++++++------------ ggml/src/ggml-cuda/ggml-cuda.cu | 9 ++--- ggml/src/ggml-metal/ggml-metal-device.h | 2 ++ ggml/src/ggml-metal/ggml-metal-device.m | 2 ++ ggml/src/ggml-metal/ggml-metal.cpp | 7 ++-- ggml/src/ggml-sycl/ggml-sycl.cpp | 8 +++-- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 11 ++++--- 7 files changed, 45 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 162d238ae4..d7a93848df 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2541,27 +2541,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { return buft->iface.get_name == ggml_backend_cann_buffer_type_name; } -/** - * @brief Determines if a tensor operation should be offloaded to the CANN - * backend. - * - * This function checks if a given tensor operation should be offloaded to the - * CANN backend based on the operation type and the size of the tensor. It - * returns true if the second dimension (ne[1]) of the tensor is greater than or - * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS. - * - * @param backend Pointer to the CANN backend. - * @param op Pointer to the tensor operation to check. - * @return bool Returns true if the operation should be offloaded, otherwise - * false. - */ -static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; - GGML_UNUSED(dev); - - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; -} - /** * @brief Records an event on the CANN backend stream. * @@ -2637,6 +2616,7 @@ struct ggml_backend_cann_device_context { int device; std::string name; std::string description; + int op_offload_min_batch_size; }; static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) { @@ -2713,6 +2693,26 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type( return ggml_backend_cann_host_buffer_type(); } +/** + * @brief Determines if a tensor operation should be offloaded to the CANN + * backend. + * + * This function checks if a given tensor operation should be offloaded to the + * CANN backend based on the operation type and the size of the tensor. It + * returns true if the second dimension (ne[1]) of the tensor is greater than or + * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS. + * + * @param backend Pointer to the CANN backend. + * @param op Pointer to the tensor operation to check. + * @return bool Returns true if the operation should be offloaded, otherwise + * false. + */ +static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context; + + return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS; +} + /** * @brief Creates a new event for the CANN backend device. * @@ -2829,12 +2829,14 @@ ggml_backend_reg_t ggml_backend_cann_reg() { if (!initialized) { aclInit(nullptr); ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_cann_info().device_count; i++) { ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context(); dev_ctx->description = aclrtGetSocName(); dev_ctx->device = i; dev_ctx->name = GGML_CANN_NAME + std::to_string(i); + dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_cann_set_device(i); ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface = */ ggml_backend_cann_device_interface, /* .reg = */ ®, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bac69cdd1c..f021de1d74 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4122,6 +4122,7 @@ struct ggml_backend_cuda_device_context { std::string name; std::string description; std::string pci_bus_id; + int op_offload_min_batch_size; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -4676,11 +4677,9 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; - return get_op_batch_size(op) >= min_batch_size; - - GGML_UNUSED(dev); + return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size; } static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) { @@ -4848,6 +4847,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; @@ -4861,6 +4861,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); dev_ctx->pci_bus_id = pci_bus_id; + dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index d983b666ca..9c3b001487 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -219,6 +219,8 @@ struct ggml_metal_device_props { bool use_shared_buffers; bool supports_gpu_family_apple7; + + int op_offload_min_batch_size; }; ggml_metal_device_t ggml_metal_device_init(void); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 59badd0043..ff899a8170 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -782,6 +782,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; + dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 70bf6f3d98..56b59f0afd 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -625,14 +625,11 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; return (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) && - get_op_batch_size(op) >= min_batch_size; - - GGML_UNUSED(dev); - GGML_UNUSED(op); + get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size; } static ggml_backend_device_i ggml_backend_metal_device_i = { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e996d98be8..8f8176b678 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4286,6 +4286,7 @@ struct ggml_backend_sycl_device_context { int device; std::string name; std::string description; + int op_offload_min_batch_size; }; static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) { @@ -4674,9 +4675,8 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; - return get_op_batch_size(op) >= min_batch_size; - GGML_UNUSED(dev); + ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context; + return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size; } static ggml_backend_event_t @@ -4799,6 +4799,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_sycl_info().device_count; i++) { ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context; @@ -4812,6 +4813,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() { prop, dpct::dev_mgr::instance().get_device(i)))); dev_ctx->description = prop.get_name(); + dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_sycl_device_interface, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d68735a040..4d3c085f67 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -14249,6 +14249,7 @@ struct ggml_backend_vk_device_context { std::string description; bool is_integrated_gpu; std::string pci_bus_id; + int op_offload_min_batch_size; }; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { @@ -14820,12 +14821,10 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba } static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context; - return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || - (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); - - UNUSED(dev); + return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID); } static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) { @@ -14951,6 +14950,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; @@ -14960,6 +14960,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); + ctx->op_offload_min_batch_size = min_batch_size; devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, From 64848deb1887532003575db9bdf46df700c3e495 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 8 Jan 2026 10:07:58 +0100 Subject: [PATCH 18/27] llama-fit-params: free memory target per device (#18679) --- common/arg.cpp | 28 ++++++++++--- common/common.cpp | 2 +- common/common.h | 14 ++++--- include/llama.h | 2 +- src/llama.cpp | 74 ++++++++++++++++++++++----------- tools/fit-params/fit-params.cpp | 2 +- 6 files changed, 83 insertions(+), 39 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 26c790c7e0..9c0e6fbe78 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2255,7 +2255,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::vector split_arg{ it, {} }; if (split_arg.size() >= llama_max_devices()) { throw std::invalid_argument( - string_format("got %d input configs, but system only has %d devices", (int)split_arg.size(), (int)llama_max_devices()) + string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices()) ); } for (size_t i = 0; i < llama_max_devices(); ++i) { @@ -2295,10 +2295,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_FIT")); add_opt(common_arg( - { "-fitt", "--fit-target" }, "MiB", - string_format("target margin per device for --fit option, default: %zu", params.fit_params_target/(1024*1024)), - [](common_params & params, int value) { - params.fit_params_target = value * size_t(1024*1024); + { "-fitt", "--fit-target" }, "MiB0,MiB1,MiB2,...", + string_format("target margin per device for --fit, comma-separated list of values, " + "single value is broadcast across all devices, default: %zu", params.fit_params_target[0]/(1024*1024)), + [](common_params & params, const std::string & value) { + std::string arg_next = value; + + // split string by , and / + const std::regex regex{ R"([,/]+)" }; + std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; + std::vector split_arg{ it, {} }; + if (split_arg.size() >= llama_max_devices()) { + throw std::invalid_argument( + string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices()) + ); + } + if (split_arg.size() == 1) { + std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024); + return; + } + for (size_t i = 0; i < split_arg.size(); i++) { + params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024; + } } ).set_env("LLAMA_ARG_FIT_TARGET")); add_opt(common_arg( diff --git a/common/common.cpp b/common/common.cpp index 34fa3b5a42..744f0b4eeb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1097,7 +1097,7 @@ common_init_result::common_init_result(common_params & params) : if (params.fit_params) { LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__); llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx, + params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); } diff --git a/common/common.h b/common/common.h index d55a6b71fb..7794c0268b 100644 --- a/common/common.h +++ b/common/common.h @@ -332,12 +332,14 @@ struct common_params { // offload params std::vector devices; // devices to use for offloading - int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - bool fit_params = true; // whether to fit unset model/context parameters to free device memory - size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory - int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + bool fit_params = true; // whether to fit unset model/context parameters to free device memory + int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + + // margin per device in bytes for fitting parameters to free memory: + std::vector fit_params_target = std::vector(llama_max_devices(), 1024 * 1024*1024); enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs diff --git a/include/llama.h b/include/llama.h index edc4c871a1..12e4e57d0e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -495,7 +495,7 @@ extern "C" { struct llama_context_params * cparams, float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t margin, // margin of memory to leave per device in bytes + size_t * margins, // margins of memory to leave per device in bytes uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log diff --git a/src/llama.cpp b/src/llama.cpp index dfefb3d2b5..33f51a2389 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -147,9 +147,8 @@ class llama_params_fit_exception : public std::runtime_error { static void llama_params_fit_impl( const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { + size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { constexpr int64_t MiB = 1024*1024; - const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits typedef std::vector dmds_t; const llama_model_params default_mparams = llama_model_default_params(); @@ -168,6 +167,12 @@ static void llama_params_fit_impl( return; } + std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits + margins.reserve(nd); + for (size_t id = 0; id < nd; id++) { + margins.push_back(margins_s[id]); + } + std::vector dev_names; { dev_names.reserve(nd); @@ -187,9 +192,10 @@ static void llama_params_fit_impl( int64_t sum_free = 0; int64_t sum_projected_free = 0; - int64_t min_projected_free = INT64_MAX; int64_t sum_projected_used = 0; int64_t sum_projected_model = 0; + std::vector projected_free_per_device; + projected_free_per_device.reserve(nd); if (nd > 1) { LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); @@ -199,45 +205,63 @@ static void llama_params_fit_impl( const int64_t projected_used = dmd.mb.total(); const int64_t projected_free = dmd.free - projected_used; + projected_free_per_device.push_back(projected_free); sum_free += dmd.free; sum_projected_used += projected_used; sum_projected_free += projected_free; - min_projected_free = std::min(min_projected_free, projected_free); sum_projected_model += dmd.mb.model; if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB, - projected_free >= 0 ? "surplus" : "deficit"); + LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", + __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); } } assert(sum_free >= 0 && sum_projected_used >= 0); LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", __func__, sum_projected_used/MiB, sum_free/MiB); - if (min_projected_free >= margin) { - if (nd == 1) { + if (nd == 1) { + if (projected_free_per_device[0] >= margins[0]) { LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, min_projected_free/MiB, margin/MiB); + __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); + return; + } + } else { + bool changes_needed = false; + for (size_t id = 0; id < nd; id++) { + if (projected_free_per_device[id] < margins[id]) { + changes_needed = true; + break; + } + } + if (!changes_needed) { + LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); return; } - LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n", - __func__, min_projected_free/MiB, margin/MiB); - return; } // step 2: try reducing memory use by reducing the context size { - int64_t global_surplus = sum_projected_free - int64_t(nd)*margin; + int64_t global_surplus = sum_projected_free; + for (size_t id = 0; id < nd; id++) { + global_surplus -= margins[id]; + } if (global_surplus < 0) { - LLAMA_LOG_INFO(nd == 1 ? - "%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" : - "%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n", - __func__, margin/MiB, -global_surplus/MiB); + if (nd == 1) { + LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", + __func__, margins[0]/MiB, -global_surplus/MiB); + } else { + LLAMA_LOG_INFO( + "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", + __func__, -global_surplus/MiB); + } if (cparams->n_ctx == 0) { if (hp_nct > n_ctx_min) { - int64_t sum_used_target = sum_free - nd*margin_s; + int64_t sum_used_target = sum_free; + for (size_t id = 0; id < nd; id++) { + sum_used_target -= margins[id]; + } if (nd > 1) { // for multiple devices we need to be more conservative in terms of how much context we think can fit: // - for dense models only whole layers can be assigned to devices @@ -448,9 +472,9 @@ static void llama_params_fit_impl( const dmds_t dmds_cpu_moe = llama_get_device_memory_data( path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const llama_device_memory_data & dmd : dmds_cpu_moe) { - global_surplus_cpu_moe += dmd.free; - global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin; + for (size_t id = 0; id < nd; id++) { + global_surplus_cpu_moe += dmds_cpu_moe[id].free; + global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; } if (global_surplus_cpu_moe > 0) { @@ -469,7 +493,7 @@ static void llama_params_fit_impl( std::vector targets; // maximum acceptable memory use per device targets.reserve(nd); for (size_t id = 0; id < nd; id++) { - targets.push_back(dmds_full[id].free - margin); + targets.push_back(dmds_full[id].free - margins[id]); LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); } @@ -701,11 +725,11 @@ static void llama_params_fit_impl( enum llama_params_fit_status llama_params_fit( const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { + size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) { const int64_t t0_us = llama_time_us(); llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS; try { - llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level); + llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); } catch (const llama_params_fit_exception & e) { LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); diff --git a/tools/fit-params/fit-params.cpp b/tools/fit-params/fit-params.cpp index c7e7748ca9..f9d9cb34c7 100644 --- a/tools/fit-params/fit-params.cpp +++ b/tools/fit-params/fit-params.cpp @@ -27,7 +27,7 @@ int main(int argc, char ** argv) { auto mparams = common_model_params_to_llama(params); auto cparams = common_context_params_to_llama(params); const llama_params_fit_status status = llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx, + params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) { LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__); From 945bf106276c664498cf6c95731aa6ceb43657ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EB=8F=84=EB=A1=9C=EB=A1=9C=EB=8F=84=EB=A1=9C=EB=98=90?= <60079918+dororodoroddo@users.noreply.github.com> Date: Thu, 8 Jan 2026 19:37:45 +0900 Subject: [PATCH 19/27] metal : add MoE kernel specialization for ne20=5 (#18667) Add template specialization for kernel_mul_mm_id_map0 with ne20=5 to support models using 5 active experts (e.g., VAETKI). --- ggml/src/ggml-metal/ggml-metal.metal | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 67b30e0d93..16d17d26af 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9148,6 +9148,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; +template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>; template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; From f2f6c88067e0da7cd1696fb6f78b0d1f5021262e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 8 Jan 2026 13:40:23 +0200 Subject: [PATCH 20/27] scripts : support chaining commands in pr2wt.sh (#18671) --- scripts/pr2wt.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh index 36ccde2f34..7970bec371 100755 --- a/scripts/pr2wt.sh +++ b/scripts/pr2wt.sh @@ -9,6 +9,7 @@ # sample usage: # ./scripts/pr2wt.sh 12345 # ./scripts/pr2wt.sh 12345 opencode +# ./scripts/pr2wt.sh 12345 "cmake -B build && cmake --build build" function usage() { echo "usage: $0 [cmd]" @@ -46,7 +47,7 @@ head_ref=$(echo "$meta" | jq -r '.head.ref') echo "url: $url_remote" echo "head_ref: $head_ref" -git remote rm pr/${PR} +git remote rm pr/${PR} 2> /dev/null git remote add pr/${PR} $url_remote git fetch pr/${PR} $head_ref @@ -62,5 +63,5 @@ echo "git worktree created in $wt_path" # if a command was provided, execute it if [[ $# -eq 2 ]]; then cd ../$dir-pr-$PR - exec $2 + eval "$2" fi From 55abc393552f3f2097f168cb6db4dc495a514d56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Thu, 8 Jan 2026 13:53:54 +0100 Subject: [PATCH 21/27] vendor : update cpp-httplib to 0.30.0 (#18660) * vendor : update cpp-httplib to 0.30.0 * common : allow custom headers when downloading --- common/arg.h | 8 - common/download.cpp | 86 +- common/download.h | 23 +- scripts/sync_vendor.py | 2 +- tests/test-arg-parser.cpp | 1 + tools/server/server-common.cpp | 4 +- vendor/cpp-httplib/httplib.cpp | 1486 ++++++++++++++++++++++++++------ vendor/cpp-httplib/httplib.h | 1124 +++++++++++++++++++----- 8 files changed, 2188 insertions(+), 546 deletions(-) diff --git a/common/arg.h b/common/arg.h index a1b6a14e67..55782a158d 100644 --- a/common/arg.h +++ b/common/arg.h @@ -129,11 +129,3 @@ void common_params_add_preset_options(std::vector & args); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); - -struct common_remote_params { - std::vector headers; - long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout - long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB -}; -// get remote file content, returns -std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); diff --git a/common/download.cpp b/common/download.cpp index ef87472560..6f56b5518f 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -308,7 +308,8 @@ static bool common_download_head(CURL * curl, // download one single file from remote URL to local path static bool common_download_file_single_online(const std::string & url, const std::string & path, - const std::string & bearer_token) { + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; for (int i = 0; i < max_attempts; ++i) { @@ -330,6 +331,11 @@ static bool common_download_file_single_online(const std::string & url, common_load_model_from_url_headers headers; curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); curl_slist_ptr http_headers; + + for (const auto & h : custom_headers) { + std::string s = h.first + ": " + h.second; + http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str()); + } const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token); if (!was_perform_successful) { head_request_ok = false; @@ -454,8 +460,10 @@ std::pair> common_remote_get_content(const std::string & curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); } http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + for (const auto & header : params.headers) { - http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); + std::string header_ = header.first + ": " + header.second; + http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str()); } curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); @@ -619,7 +627,8 @@ static bool common_pull_file(httplib::Client & cli, // download one single file from remote URL to local path static bool common_download_file_single_online(const std::string & url, const std::string & path, - const std::string & bearer_token) { + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; @@ -629,6 +638,9 @@ static bool common_download_file_single_online(const std::string & url, if (!bearer_token.empty()) { default_headers.insert({"Authorization", "Bearer " + bearer_token}); } + for (const auto & h : custom_headers) { + default_headers.emplace(h.first, h.second); + } cli.set_default_headers(default_headers); const bool file_exists = std::filesystem::exists(path); @@ -734,13 +746,9 @@ std::pair> common_remote_get_content(const std::string auto [cli, parts] = common_http_client(url); httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; + for (const auto & header : params.headers) { - size_t pos = header.find(':'); - if (pos != std::string::npos) { - headers.emplace(header.substr(0, pos), header.substr(pos + 1)); - } else { - headers.emplace(header, ""); - } + headers.emplace(header.first, header.second); } if (params.timeout > 0) { @@ -772,9 +780,10 @@ std::pair> common_remote_get_content(const std::string static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, - bool offline) { + bool offline, + const common_header_list & headers) { if (!offline) { - return common_download_file_single_online(url, path, bearer_token); + return common_download_file_single_online(url, path, bearer_token, headers); } if (!std::filesystem::exists(path)) { @@ -788,13 +797,24 @@ static bool common_download_file_single(const std::string & url, // download multiple files from remote URLs to local paths // the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { +static bool common_download_file_multiple(const std::vector> & urls, + const std::string & bearer_token, + bool offline, + const common_header_list & headers) { // Prepare download in parallel std::vector> futures_download; + futures_download.reserve(urls.size()); + for (auto const & item : urls) { - futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token, offline); - }, item)); + futures_download.push_back( + std::async( + std::launch::async, + [&bearer_token, offline, &headers](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline, headers); + }, + item + ) + ); } // Wait for all downloads to complete @@ -807,17 +827,17 @@ static bool common_download_file_multiple(const std::vector(hf_repo_with_tag, ':'); std::string tag = parts.size() > 1 ? parts.back() : "latest"; std::string hf_repo = parts[0]; @@ -893,10 +916,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; // headers - std::vector headers; - headers.push_back("Accept: application/json"); + common_header_list headers = custom_headers; + headers.push_back({"Accept", "application/json"}); if (!bearer_token.empty()) { - headers.push_back("Authorization: Bearer " + bearer_token); + headers.push_back({"Authorization", "Bearer " + bearer_token}); } // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response // User-Agent header is already set in common_remote_get_content, no need to set it here @@ -1031,9 +1054,10 @@ std::string common_docker_resolve_model(const std::string & docker) { const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo; std::string manifest_url = url_prefix + "/manifests/" + tag; common_remote_params manifest_params; - manifest_params.headers.push_back("Authorization: Bearer " + token); - manifest_params.headers.push_back( - "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); + manifest_params.headers.push_back({"Authorization", "Bearer " + token}); + manifest_params.headers.push_back({"Accept", + "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json" + }); auto manifest_res = common_remote_get_content(manifest_url, manifest_params); if (manifest_res.first != 200) { throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); @@ -1070,7 +1094,7 @@ std::string common_docker_resolve_model(const std::string & docker) { std::string local_path = fs_get_cache_file(model_filename); const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; - if (!common_download_file_single(blob_url, local_path, token, false)) { + if (!common_download_file_single(blob_url, local_path, token, false, {})) { throw std::runtime_error("Failed to download Docker Model"); } @@ -1084,11 +1108,11 @@ std::string common_docker_resolve_model(const std::string & docker) { #else -common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) { +common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) { throw std::runtime_error("download functionality is not enabled in this build"); } -bool common_download_model(const common_params_model &, const std::string &, bool) { +bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) { throw std::runtime_error("download functionality is not enabled in this build"); } diff --git a/common/download.h b/common/download.h index d1321e6e90..9ea2093939 100644 --- a/common/download.h +++ b/common/download.h @@ -1,12 +1,21 @@ #pragma once #include +#include struct common_params_model; -// -// download functionalities -// +using common_header = std::pair; +using common_header_list = std::vector; + +struct common_remote_params { + common_header_list headers; + long timeout = 0; // in seconds, 0 means no timeout + long max_size = 0; // unlimited if 0 +}; + +// get remote file content, returns +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); struct common_cached_model_info { std::string manifest_path; @@ -41,13 +50,17 @@ struct common_hf_file_res { common_hf_file_res common_get_hf_file( const std::string & hf_repo_with_tag, const std::string & bearer_token, - bool offline); + bool offline, + const common_header_list & headers = {} +); // returns true if download succeeded bool common_download_model( const common_params_model & model, const std::string & bearer_token, - bool offline); + bool offline, + const common_header_list & headers = {} +); // returns list of cached models std::vector common_list_cached_models(); diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 637f4cdc18..ed6bf1bf4e 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -16,7 +16,7 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.0/httplib.h": "vendor/cpp-httplib/httplib.h", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index e995974a2e..c7be0021be 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -1,5 +1,6 @@ #include "arg.h" #include "common.h" +#include "download.h" #include #include diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e4a0be44cc..16b0db2983 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1,10 +1,10 @@ #include "common.h" +#include "download.h" #include "log.h" #include "llama.h" #include "mtmd.h" #include "mtmd-helper.h" #include "chat.h" -#include "arg.h" // for common_remote_get_content; TODO: use download.h only #include "base64.hpp" #include "server-common.h" @@ -779,7 +779,7 @@ static void handle_media( // download remote image // TODO @ngxson : maybe make these params configurable common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.headers.push_back({"User-Agent", "llama.cpp/" + build_info}); params.max_size = 1024 * 1024 * 10; // 10MB params.timeout = 10; // seconds SRV_INF("downloading image from '%s'\n", url.c_str()); diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index b86e6a2310..a437a36ed7 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -9,7 +9,7 @@ namespace httplib { namespace detail { bool is_hex(char c, int &v) { - if (0x20 <= c && isdigit(c)) { + if (isdigit(c)) { v = c - '0'; return true; } else if ('A' <= c && c <= 'F') { @@ -49,6 +49,90 @@ std::string from_i_to_hex(size_t n) { return ret; } +std::string compute_etag(const FileStat &fs) { + if (!fs.is_file()) { return std::string(); } + + // If mtime cannot be determined (negative value indicates an error + // or sentinel), do not generate an ETag. Returning a neutral / fixed + // value like 0 could collide with a real file that legitimately has + // mtime == 0 (epoch) and lead to misleading validators. + auto mtime_raw = fs.mtime(); + if (mtime_raw < 0) { return std::string(); } + + auto mtime = static_cast(mtime_raw); + auto size = fs.size(); + + return std::string("W/\"") + from_i_to_hex(mtime) + "-" + + from_i_to_hex(size) + "\""; +} + +// Format time_t as HTTP-date (RFC 9110 Section 5.6.7): "Sun, 06 Nov 1994 +// 08:49:37 GMT" This implementation is defensive: it validates `mtime`, checks +// return values from `gmtime_r`/`gmtime_s`, and ensures `strftime` succeeds. +std::string file_mtime_to_http_date(time_t mtime) { + if (mtime < 0) { return std::string(); } + + struct tm tm_buf; +#ifdef _WIN32 + if (gmtime_s(&tm_buf, &mtime) != 0) { return std::string(); } +#else + if (gmtime_r(&mtime, &tm_buf) == nullptr) { return std::string(); } +#endif + char buf[64]; + if (strftime(buf, sizeof(buf), "%a, %d %b %Y %H:%M:%S GMT", &tm_buf) == 0) { + return std::string(); + } + + return std::string(buf); +} + +// Parse HTTP-date (RFC 9110 Section 5.6.7) to time_t. Returns -1 on failure. +time_t parse_http_date(const std::string &date_str) { + struct tm tm_buf; + + // Create a classic locale object once for all parsing attempts + const std::locale classic_locale = std::locale::classic(); + + // Try to parse using std::get_time (C++11, cross-platform) + auto try_parse = [&](const char *fmt) -> bool { + std::istringstream ss(date_str); + ss.imbue(classic_locale); + + memset(&tm_buf, 0, sizeof(tm_buf)); + ss >> std::get_time(&tm_buf, fmt); + + return !ss.fail(); + }; + + // RFC 9110 preferred format (HTTP-date): "Sun, 06 Nov 1994 08:49:37 GMT" + if (!try_parse("%a, %d %b %Y %H:%M:%S")) { + // RFC 850 format: "Sunday, 06-Nov-94 08:49:37 GMT" + if (!try_parse("%A, %d-%b-%y %H:%M:%S")) { + // asctime format: "Sun Nov 6 08:49:37 1994" + if (!try_parse("%a %b %d %H:%M:%S %Y")) { + return static_cast(-1); + } + } + } + +#ifdef _WIN32 + return _mkgmtime(&tm_buf); +#else + return timegm(&tm_buf); +#endif +} + +bool is_weak_etag(const std::string &s) { + // Check if the string is a weak ETag (starts with 'W/"') + return s.size() > 3 && s[0] == 'W' && s[1] == '/' && s[2] == '"'; +} + +bool is_strong_etag(const std::string &s) { + // Check if the string is a strong ETag (starts and ends with '"', at least 2 + // chars) + return s.size() >= 2 && s[0] == '"' && s.back() == '"'; +} + size_t to_utf8(int code, char *buff) { if (code < 0x0080) { buff[0] = static_cast(code & 0x7F); @@ -168,6 +252,15 @@ bool FileStat::is_dir() const { return ret_ >= 0 && S_ISDIR(st_.st_mode); } +time_t FileStat::mtime() const { + return ret_ >= 0 ? static_cast(st_.st_mtime) + : static_cast(-1); +} + +size_t FileStat::size() const { + return ret_ >= 0 ? static_cast(st_.st_size) : 0; +} + std::string encode_path(const std::string &s) { std::string result; result.reserve(s.size()); @@ -209,6 +302,149 @@ std::string file_extension(const std::string &path) { bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } +template +bool parse_header(const char *beg, const char *end, T fn); + +template +bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = std::string(p, end); + + if (!detail::fields::is_field_value(val)) { return false; } + + if (case_ignore::equal(key, "Location") || + case_ignore::equal(key, "Referer")) { + fn(key, val); + } else { + fn(key, decode_path_component(val)); + } + + return true; + } + + return false; +} + +bool parse_trailers(stream_line_reader &line_reader, Headers &dest, + const Headers &src_headers) { + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // doesn't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. + + // RFC 7230 Section 4.1.2 - Headers prohibited in trailers + thread_local case_ignore::unordered_set prohibited_trailers = { + "transfer-encoding", + "content-length", + "host", + "authorization", + "www-authenticate", + "proxy-authenticate", + "proxy-authorization", + "cookie", + "set-cookie", + "cache-control", + "expect", + "max-forwards", + "pragma", + "range", + "te", + "age", + "expires", + "date", + "location", + "retry-after", + "vary", + "warning", + "content-encoding", + "content-type", + "content-range", + "trailer"}; + + case_ignore::unordered_set declared_trailers; + auto trailer_header = get_header_value(src_headers, "Trailer", "", 0); + if (trailer_header && std::strlen(trailer_header)) { + auto len = std::strlen(trailer_header); + split(trailer_header, trailer_header + len, ',', + [&](const char *b, const char *e) { + const char *kbeg = b; + const char *kend = e; + while (kbeg < kend && (*kbeg == ' ' || *kbeg == '\t')) { + ++kbeg; + } + while (kend > kbeg && (kend[-1] == ' ' || kend[-1] == '\t')) { + --kend; + } + std::string key(kbeg, static_cast(kend - kbeg)); + if (!key.empty() && + prohibited_trailers.find(key) == prohibited_trailers.end()) { + declared_trailers.insert(key); + } + }); + } + + size_t trailer_header_count = 0; + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { return false; } + + constexpr auto line_terminator_len = 2; + auto line_beg = line_reader.ptr(); + auto line_end = + line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_beg, line_end, + [&](const std::string &key, const std::string &val) { + if (declared_trailers.find(key) != + declared_trailers.end()) { + dest.emplace(key, val); + trailer_header_count++; + } + })) { + return false; + } + + if (!line_reader.getline()) { return false; } + } + + return true; +} + std::pair trim(const char *b, const char *e, size_t left, size_t right) { while (b + left < e && is_space_or_tab(b[left])) { @@ -280,6 +516,42 @@ void split(const char *b, const char *e, char d, size_t m, } } +bool split_find(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + auto found = fn(&b[r.first], &b[r.second]); + if (found) { return true; } + } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + auto found = fn(&b[r.first], &b[r.second]); + if (found) { return true; } + } + } + + return false; +} + +bool split_find(const char *b, const char *e, char d, + std::function fn) { + return split_find(b, e, d, (std::numeric_limits::max)(), + std::move(fn)); +} + stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) : strm_(strm), fixed_buffer_(fixed_buffer), @@ -1892,6 +2164,27 @@ bool zstd_decompressor::decompress(const char *data, size_t data_length, } #endif +std::unique_ptr +create_decompressor(const std::string &encoding) { + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#endif + } else if (encoding == "zstd" || encoding.find("zstd") != std::string::npos) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#endif + } + + return decompressor; +} + bool is_prohibited_header_name(const std::string &name) { using udl::operator""_t; @@ -1928,53 +2221,6 @@ const char *get_header_value(const Headers &headers, return def; } -template -bool parse_header(const char *beg, const char *end, T fn) { - // Skip trailing spaces and tabs. - while (beg < end && is_space_or_tab(end[-1])) { - end--; - } - - auto p = beg; - while (p < end && *p != ':') { - p++; - } - - auto name = std::string(beg, p); - if (!detail::fields::is_field_name(name)) { return false; } - - if (p == end) { return false; } - - auto key_end = p; - - if (*p++ != ':') { return false; } - - while (p < end && is_space_or_tab(*p)) { - p++; - } - - if (p <= end) { - auto key_len = key_end - beg; - if (!key_len) { return false; } - - auto key = std::string(beg, key_end); - auto val = std::string(p, end); - - if (!detail::fields::is_field_value(val)) { return false; } - - if (case_ignore::equal(key, "Location") || - case_ignore::equal(key, "Referer")) { - fn(key, val); - } else { - fn(key, decode_path_component(val)); - } - - return true; - } - - return false; -} - bool read_headers(Stream &strm, Headers &headers) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -2026,10 +2272,18 @@ bool read_content_with_length(Stream &strm, size_t len, ContentReceiverWithProgress out) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; + detail::BodyReader br; + br.stream = &strm; + br.content_length = len; + br.chunked = false; + br.bytes_read = 0; + br.last_error = Error::Success; + size_t r = 0; while (r < len) { auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + auto to_read = (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ); + auto n = detail::read_body_content(&strm, br, buf, to_read); if (n <= 0) { return false; } if (!out(buf, static_cast(n), r, len)) { return false; } @@ -2089,125 +2343,35 @@ template ReadContentResult read_content_chunked(Stream &strm, T &x, size_t payload_max_length, ContentReceiverWithProgress out) { - const auto bufsiz = 16; - char buf[bufsiz]; + detail::ChunkedDecoder dec(strm); - stream_line_reader line_reader(strm, buf, bufsiz); - - if (!line_reader.getline()) { return ReadContentResult::Error; } - - unsigned long chunk_len; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; size_t total_len = 0; - while (true) { - char *end_ptr; - chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + for (;;) { + size_t chunk_offset = 0; + size_t chunk_total = 0; + auto n = dec.read_payload(buf, sizeof(buf), chunk_offset, chunk_total); + if (n < 0) { return ReadContentResult::Error; } - if (end_ptr == line_reader.ptr()) { return ReadContentResult::Error; } - if (chunk_len == ULONG_MAX) { return ReadContentResult::Error; } + if (n == 0) { + if (!dec.parse_trailers_into(x.trailers, x.headers)) { + return ReadContentResult::Error; + } + return ReadContentResult::Success; + } - if (chunk_len == 0) { break; } - - // Check if adding this chunk would exceed the payload limit if (total_len > payload_max_length || - payload_max_length - total_len < chunk_len) { + payload_max_length - total_len < static_cast(n)) { return ReadContentResult::PayloadTooLarge; } - total_len += chunk_len; - - if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + if (!out(buf, static_cast(n), chunk_offset, chunk_total)) { return ReadContentResult::Error; } - if (!line_reader.getline()) { return ReadContentResult::Error; } - - if (strcmp(line_reader.ptr(), "\r\n") != 0) { - return ReadContentResult::Error; - } - - if (!line_reader.getline()) { return ReadContentResult::Error; } + total_len += static_cast(n); } - - assert(chunk_len == 0); - - // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked - // transfer coding is complete when a chunk with a chunk-size of zero is - // received, possibly followed by a trailer section, and finally terminated by - // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 - // - // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section - // does't care for the existence of the final CRLF. In other words, it seems - // to be ok whether the final CRLF exists or not in the chunked data. - // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 - // - // According to the reference code in RFC 9112, cpp-httplib now allows - // chunked transfer coding data without the final CRLF. - if (!line_reader.getline()) { return ReadContentResult::Success; } - - // RFC 7230 Section 4.1.2 - Headers prohibited in trailers - thread_local case_ignore::unordered_set prohibited_trailers = { - // Message framing - "transfer-encoding", "content-length", - - // Routing - "host", - - // Authentication - "authorization", "www-authenticate", "proxy-authenticate", - "proxy-authorization", "cookie", "set-cookie", - - // Request modifiers - "cache-control", "expect", "max-forwards", "pragma", "range", "te", - - // Response control - "age", "expires", "date", "location", "retry-after", "vary", "warning", - - // Payload processing - "content-encoding", "content-type", "content-range", "trailer"}; - - // Parse declared trailer headers once for performance - case_ignore::unordered_set declared_trailers; - if (has_header(x.headers, "Trailer")) { - auto trailer_header = get_header_value(x.headers, "Trailer", "", 0); - auto len = std::strlen(trailer_header); - - split(trailer_header, trailer_header + len, ',', - [&](const char *b, const char *e) { - std::string key(b, e); - if (prohibited_trailers.find(key) == prohibited_trailers.end()) { - declared_trailers.insert(key); - } - }); - } - - size_t trailer_header_count = 0; - while (strcmp(line_reader.ptr(), "\r\n") != 0) { - if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { - return ReadContentResult::Error; - } - - // Check trailer header count limit - if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { - return ReadContentResult::Error; - } - - // Exclude line terminator - constexpr auto line_terminator_len = 2; - auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; - - parse_header(line_reader.ptr(), end, - [&](const std::string &key, const std::string &val) { - if (declared_trailers.find(key) != declared_trailers.end()) { - x.trailers.emplace(key, val); - trailer_header_count++; - } - }); - - if (!line_reader.getline()) { return ReadContentResult::Error; } - } - - return ReadContentResult::Success; } bool is_chunked_transfer_encoding(const Headers &headers) { @@ -2223,27 +2387,13 @@ bool prepare_content_receiver(T &x, int &status, std::string encoding = x.get_header_value("Content-Encoding"); std::unique_ptr decompressor; - if (encoding == "gzip" || encoding == "deflate") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif - } else if (encoding.find("br") != std::string::npos) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif - } else if (encoding == "zstd") { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif + if (!encoding.empty()) { + decompressor = detail::create_decompressor(encoding); + if (!decompressor) { + // Unsupported encoding or no support compiled in + status = StatusCode::UnsupportedMediaType_415; + return false; + } } if (decompressor) { @@ -2329,7 +2479,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, ssize_t write_request_line(Stream &strm, const std::string &method, const std::string &path) { std::string s = method; - s += " "; + s += ' '; s += path; s += " HTTP/1.1\r\n"; return strm.write(s.data(), s.size()); @@ -2338,7 +2488,7 @@ ssize_t write_request_line(Stream &strm, const std::string &method, ssize_t write_response_line(Stream &strm, int status) { std::string s = "HTTP/1.1 "; s += std::to_string(status); - s += " "; + s += ' '; s += httplib::status_message(status); s += "\r\n"; return strm.write(s.data(), s.size()); @@ -2601,8 +2751,8 @@ bool redirect(T &cli, Request &req, Response &res, auto ret = cli.send(new_req, new_res, error); if (ret) { - req = new_req; - res = new_res; + req = std::move(new_req); + res = std::move(new_res); if (res.location.empty()) { res.location = location; } } @@ -2613,9 +2763,9 @@ std::string params_to_query_str(const Params ¶ms) { std::string query; for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { query += "&"; } + if (it != params.begin()) { query += '&'; } query += encode_query_component(it->first); - query += "="; + query += '='; query += encode_query_component(it->second); } return query; @@ -2648,6 +2798,38 @@ void parse_query_text(const std::string &s, Params ¶ms) { parse_query_text(s.data(), s.size(), params); } +// Normalize a query string by decoding and re-encoding each key/value pair +// while preserving the original parameter order. This avoids double-encoding +// and ensures consistent encoding without reordering (unlike Params which +// uses std::multimap and sorts keys). +std::string normalize_query_string(const std::string &query) { + std::string result; + split(query.data(), query.data() + query.size(), '&', + [&](const char *b, const char *e) { + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + auto dec_key = decode_query_component(key); + auto dec_val = decode_query_component(val); + + if (!result.empty()) { result += '&'; } + result += encode_query_component(dec_key); + if (!val.empty() || std::find(b, e, '=') != e) { + result += '='; + result += encode_query_component(dec_val); + } + } + }); + return result; +} + bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { auto boundary_keyword = "boundary="; @@ -2840,7 +3022,7 @@ bool parse_accept_header(const std::string &s, return; } - entries.push_back(accept_entry); + entries.push_back(std::move(accept_entry)); }); // Return false if any invalid entry was found @@ -2857,8 +3039,8 @@ bool parse_accept_header(const std::string &s, // Extract sorted media types content_types.reserve(entries.size()); - for (const auto &entry : entries) { - content_types.push_back(entry.media_type); + for (auto &entry : entries) { + content_types.push_back(std::move(entry.media_type)); } return true; @@ -2869,7 +3051,7 @@ public: FormDataParser() = default; void set_boundary(std::string &&boundary) { - boundary_ = boundary; + boundary_ = std::move(boundary); dash_boundary_crlf_ = dash_ + boundary_ + crlf_; crlf_dash_boundary_ = crlf_ + dash_ + boundary_; } @@ -3342,9 +3524,9 @@ std::string make_content_range_header_field( std::string field = "bytes "; field += std::to_string(st); - field += "-"; + field += '-'; field += std::to_string(ed); - field += "/"; + field += '/'; field += std::to_string(content_length); return field; } @@ -3721,7 +3903,7 @@ bool parse_www_authenticate(const Response &res, static_cast(m.length(2))) : s.substr(static_cast(m.position(3)), static_cast(m.length(3))); - auth[key] = val; + auth[std::move(key)] = std::move(val); } return true; } @@ -3734,7 +3916,7 @@ class ContentProviderAdapter { public: explicit ContentProviderAdapter( ContentProviderWithoutLength &&content_provider) - : content_provider_(content_provider) {} + : content_provider_(std::move(content_provider)) {} bool operator()(size_t offset, size_t, DataSink &sink) { return content_provider_(offset, sink); @@ -3744,8 +3926,189 @@ private: ContentProviderWithoutLength content_provider_; }; +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +bool is_field_name(const std::string &s) { return is_token(s); } + +bool is_vchar(char c) { return c >= 33 && c <= 126; } + +bool is_obs_text(char c) { return 128 <= static_cast(c); } + +bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +bool is_field_content(const std::string &s) { + if (s.empty()) { return true; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + } // namespace detail +const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Unknown: return "Unknown"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::ConnectionClosed: return "Connection closed by server"; + case Error::Timeout: return "Read timeout"; + case Error::ResourceExhaustion: return "Resource exhaustion"; + case Error::TooManyFormDataFiles: return "Too many form data files"; + case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size"; + case Error::ExceedUriMaxLength: return "Exceeded maximum URI length"; + case Error::ExceedMaxSocketDescriptorCount: + return "Exceeded maximum socket descriptor count"; + case Error::InvalidRequestLine: return "Invalid request line"; + case Error::InvalidHTTPMethod: return "Invalid HTTP method"; + case Error::InvalidHTTPVersion: return "Invalid HTTP version"; + case Error::InvalidHeaders: return "Invalid headers"; + case Error::MultipartParsing: return "Multipart parsing failed"; + case Error::OpenFile: return "Failed to open file"; + case Error::Listen: return "Failed to listen on socket"; + case Error::GetSockName: return "Failed to get socket name"; + case Error::UnsupportedAddressFamily: return "Unsupported address family"; + case Error::HTTPParsing: return "HTTP parsing failed"; + case Error::InvalidRangeHeader: return "Invalid Range header"; + default: break; + } + + return "Invalid"; +} + +std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + std::string hosted_at(const std::string &hostname) { std::vector addrs; hosted_at(hostname, addrs); @@ -3779,7 +4142,7 @@ void hosted_at(const std::string &hostname, auto dummy = -1; if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, dummy)) { - addrs.push_back(ip); + addrs.emplace_back(std::move(ip)); } } } @@ -4319,6 +4682,67 @@ ssize_t Stream::write(const std::string &s) { return write(s.data(), s.size()); } +// BodyReader implementation +ssize_t detail::BodyReader::read(char *buf, size_t len) { + if (!stream) { + last_error = Error::Connection; + return -1; + } + if (eof) { return 0; } + + if (!chunked) { + // Content-Length based reading + if (bytes_read >= content_length) { + eof = true; + return 0; + } + + auto remaining = content_length - bytes_read; + auto to_read = (std::min)(len, remaining); + auto n = stream->read(buf, to_read); + + if (n < 0) { + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return n; + } + if (n == 0) { + // Unexpected EOF before content_length + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return 0; + } + + bytes_read += static_cast(n); + if (bytes_read >= content_length) { eof = true; } + return n; + } + + // Chunked transfer encoding: delegate to shared decoder instance. + if (!chunked_decoder) { chunked_decoder.reset(new ChunkedDecoder(*stream)); } + + size_t chunk_offset = 0; + size_t chunk_total = 0; + auto n = chunked_decoder->read_payload(buf, len, chunk_offset, chunk_total); + if (n < 0) { + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return n; + } + + if (n == 0) { + // Final chunk observed. Leave trailer parsing to the caller (StreamHandle). + eof = true; + return 0; + } + + bytes_read += static_cast(n); + return n; +} + namespace detail { void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, @@ -4395,7 +4819,10 @@ ssize_t SocketStream::read(char *ptr, size_t size) { } } - if (!wait_readable()) { return -1; } + if (!wait_readable()) { + error_ = Error::Timeout; + return -1; + } read_buff_off_ = 0; read_buff_content_size_ = 0; @@ -4404,6 +4831,11 @@ ssize_t SocketStream::read(char *ptr, size_t size) { auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, CPPHTTPLIB_RECV_FLAGS); if (n <= 0) { + if (n == 0) { + error_ = Error::ConnectionClosed; + } else { + error_ = Error::Read; + } return n; } else if (n <= static_cast(size)) { memcpy(ptr, read_buff_.data(), static_cast(n)); @@ -4415,7 +4847,15 @@ ssize_t SocketStream::read(char *ptr, size_t size) { return static_cast(size); } } else { - return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + auto n = read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + if (n == 0) { + error_ = Error::ConnectionClosed; + } else { + error_ = Error::Read; + } + } + return n; } } @@ -4579,19 +5019,22 @@ bool RegexMatcher::match(Request &request) const { return std::regex_match(request.path, request.matches, regex_); } -std::string make_host_and_port_string(const std::string &host, int port, - bool is_ssl) { - std::string result; - +// Enclose IPv6 address in brackets if needed +std::string prepare_host_string(const std::string &host) { // Enclose IPv6 address in brackets (but not if already enclosed) if (host.find(':') == std::string::npos || (!host.empty() && host[0] == '[')) { // IPv4, hostname, or already bracketed IPv6 - result = host; + return host; } else { // IPv6 address without brackets - result = "[" + host + "]"; + return "[" + host + "]"; } +} + +std::string make_host_and_port_string(const std::string &host, int port, + bool is_ssl) { + auto result = prepare_host_string(host); // Append port if not default if ((!is_ssl && port == 80) || (is_ssl && port == 443)) { @@ -4603,6 +5046,29 @@ std::string make_host_and_port_string(const std::string &host, int port, return result; } +// Create "host:port" string always including port number (for CONNECT method) +std::string +make_host_and_port_string_always_port(const std::string &host, int port) { + return prepare_host_string(host) + ":" + std::to_string(port); +} + +template +bool check_and_write_headers(Stream &strm, Headers &headers, + T header_writer, Error &error) { + for (const auto &h : headers) { + if (!detail::fields::is_field_name(h.first) || + !detail::fields::is_field_value(h.second)) { + error = Error::InvalidHeaders; + return false; + } + } + if (header_writer(strm, headers) <= 0) { + error = Error::Write; + return false; + } + return true; +} + } // namespace detail // HTTP server implementation @@ -4694,7 +5160,7 @@ bool Server::set_mount_point(const std::string &mount_point, if (stat.is_dir()) { std::string mnt = !mount_point.empty() ? mount_point : "/"; if (!mnt.empty() && mnt[0] == '/') { - base_dirs_.push_back({mnt, dir, std::move(headers)}); + base_dirs_.push_back({std::move(mnt), dir, std::move(headers)}); return true; } } @@ -5010,7 +5476,7 @@ bool Server::write_response_core(Stream &strm, bool close_connection, { detail::BufferStream bstrm; if (!detail::write_response_line(bstrm, res.status)) { return false; } - if (!header_writer_(bstrm, res.headers)) { return false; } + if (header_writer_(bstrm, res.headers) <= 0) { return false; } // Flush buffer auto &data = bstrm.get_buffer(); @@ -5103,7 +5569,16 @@ bool Server::read_content(Stream &strm, Request &req, Response &res) { strm, req, res, // Regular [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { return false; } + // Prevent arithmetic overflow when checking sizes. + // Avoid computing (req.body.size() + n) directly because + // adding two unsigned `size_t` values can wrap around and + // produce a small result instead of indicating overflow. + // Instead, check using subtraction: ensure `n` does not + // exceed the remaining capacity `max_size() - size()`. + if (req.body.size() >= req.body.max_size() || + n > req.body.max_size() - req.body.size()) { + return false; + } req.body.append(buf, n); return true; }, @@ -5186,10 +5661,39 @@ bool Server::read_content_core( // RFC 7230 Section 3.3.3: If this is a request message and none of the above // are true (no Transfer-Encoding and no Content-Length), then the message // body length is zero (no message body is present). + // + // For non-SSL builds, peek into the socket to detect clients that send a + // body without a Content-Length header (raw HTTP over TCP). If there is + // pending data that exceeds the configured payload limit, treat this as an + // oversized request and fail early (causing connection close). For SSL + // builds we cannot reliably peek the decrypted application bytes, so keep + // the original behaviour. +#if !defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(_WIN32) + if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { + socket_t s = strm.socket(); + if (s != INVALID_SOCKET) { + // Peek up to payload_max_length_ + 1 bytes. If more than + // payload_max_length_ bytes are pending, reject the request. + size_t to_peek = + (payload_max_length_ > 0) + ? (std::min)(payload_max_length_ + 1, static_cast(4096)) + : 1; + std::vector peekbuf(to_peek); + ssize_t n = ::recv(s, peekbuf.data(), to_peek, MSG_PEEK); + if (n > 0 && static_cast(n) > payload_max_length_) { + // Indicate failure so connection will be closed. + return false; + } + } + return true; + } +#else if (!req.has_header("Content-Length") && !detail::is_chunked_transfer_encoding(req.headers)) { return true; } +#endif if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, out, true)) { @@ -5207,7 +5711,7 @@ bool Server::read_content_core( return true; } -bool Server::handle_file_request(const Request &req, Response &res) { +bool Server::handle_file_request(Request &req, Response &res) { for (const auto &entry : base_dirs_) { // Prefix match if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { @@ -5228,6 +5732,20 @@ bool Server::handle_file_request(const Request &req, Response &res) { res.set_header(kv.first, kv.second); } + auto etag = detail::compute_etag(stat); + if (!etag.empty()) { res.set_header("ETag", etag); } + + auto mtime = stat.mtime(); + + auto last_modified = detail::file_mtime_to_http_date(mtime); + if (!last_modified.empty()) { + res.set_header("Last-Modified", last_modified); + } + + if (check_if_not_modified(req, res, etag, mtime)) { return true; } + + check_if_range(req, etag, mtime); + auto mm = std::make_shared(path.c_str()); if (!mm->is_open()) { output_error_log(Error::OpenFile, &req); @@ -5257,6 +5775,79 @@ bool Server::handle_file_request(const Request &req, Response &res) { return false; } +bool Server::check_if_not_modified(const Request &req, Response &res, + const std::string &etag, + time_t mtime) const { + // Handle conditional GET: + // 1. If-None-Match takes precedence (RFC 9110 Section 13.1.2) + // 2. If-Modified-Since is checked only when If-None-Match is absent + if (req.has_header("If-None-Match")) { + if (!etag.empty()) { + auto val = req.get_header_value("If-None-Match"); + + // NOTE: We use exact string matching here. This works correctly + // because our server always generates weak ETags (W/"..."), and + // clients typically send back the same ETag they received. + // RFC 9110 Section 8.8.3.2 allows weak comparison for + // If-None-Match, where W/"x" and "x" would match, but this + // simplified implementation requires exact matches. + auto ret = detail::split_find(val.data(), val.data() + val.size(), ',', + [&](const char *b, const char *e) { + return std::equal(b, e, "*") || + std::equal(b, e, etag.begin()); + }); + + if (ret) { + res.status = StatusCode::NotModified_304; + return true; + } + } + } else if (req.has_header("If-Modified-Since")) { + auto val = req.get_header_value("If-Modified-Since"); + auto t = detail::parse_http_date(val); + + if (t != static_cast(-1) && mtime <= t) { + res.status = StatusCode::NotModified_304; + return true; + } + } + return false; +} + +bool Server::check_if_range(Request &req, const std::string &etag, + time_t mtime) const { + // Handle If-Range for partial content requests (RFC 9110 + // Section 13.1.5). If-Range is only evaluated when Range header is + // present. If the validator matches, serve partial content; otherwise + // serve full content. + if (!req.ranges.empty() && req.has_header("If-Range")) { + auto val = req.get_header_value("If-Range"); + + auto is_valid_range = [&]() { + if (detail::is_strong_etag(val)) { + // RFC 9110 Section 13.1.5: If-Range requires strong ETag + // comparison. + return (!etag.empty() && val == etag); + } else if (detail::is_weak_etag(val)) { + // Weak ETags are not valid for If-Range (RFC 9110 Section 13.1.5) + return false; + } else { + // HTTP-date comparison + auto t = detail::parse_http_date(val); + return (t != static_cast(-1) && mtime <= t); + } + }; + + if (!is_valid_range()) { + // Validator doesn't match: ignore Range and serve full content + req.ranges.clear(); + return false; + } + } + + return true; +} + socket_t Server::create_server_socket(const std::string &host, int port, int socket_flags, @@ -5524,10 +6115,13 @@ void Server::apply_ranges(const Request &req, Response &res, res.set_header("Transfer-Encoding", "chunked"); if (type == detail::EncodingType::Gzip) { res.set_header("Content-Encoding", "gzip"); + res.set_header("Vary", "Accept-Encoding"); } else if (type == detail::EncodingType::Brotli) { res.set_header("Content-Encoding", "br"); + res.set_header("Vary", "Accept-Encoding"); } else if (type == detail::EncodingType::Zstd) { res.set_header("Content-Encoding", "zstd"); + res.set_header("Vary", "Accept-Encoding"); } } } @@ -5586,6 +6180,7 @@ void Server::apply_ranges(const Request &req, Response &res, })) { res.body.swap(compressed); res.set_header("Content-Encoding", content_encoding); + res.set_header("Vary", "Accept-Encoding"); } } } @@ -5663,6 +6258,10 @@ Server::process_request(Stream &strm, const std::string &remote_addr, Request req; req.start_time_ = std::chrono::steady_clock::now(); + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.local_addr = local_addr; + req.local_port = local_port; Response res; res.version = "HTTP/1.1"; @@ -5908,7 +6507,6 @@ ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), - host_and_port_(detail::make_host_and_port_string(host_, port, is_ssl())), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} ClientImpl::~ClientImpl() { @@ -6007,6 +6605,26 @@ bool ClientImpl::create_and_connect_socket(Socket &socket, return true; } +bool ClientImpl::ensure_socket_connection(Socket &socket, Error &error) { + return create_and_connect_socket(socket, error); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +bool SSLClient::ensure_socket_connection(Socket &socket, Error &error) { + if (!ClientImpl::ensure_socket_connection(socket, error)) { return false; } + + if (!proxy_host_.empty() && proxy_port_ != -1) { return true; } + + if (!initialize_ssl(socket, error)) { + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} +#endif + void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { // If there are any requests in flight from threads other than us, then it's @@ -6119,7 +6737,7 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { } if (!is_alive) { - if (!create_and_connect_socket(socket_, error)) { + if (!ensure_socket_connection(socket_, error)) { output_error_log(error, &req); return false; } @@ -6137,9 +6755,11 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { } } - if (!scli.initialize_ssl(socket_, error)) { - output_error_log(error, &req); - return false; + if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!scli.initialize_ssl(socket_, error)) { + output_error_log(error, &req); + return false; + } } } #endif @@ -6212,6 +6832,343 @@ Result ClientImpl::send_(Request &&req) { #endif } +void ClientImpl::prepare_default_headers(Request &r, bool for_stream, + const std::string &ct) { + (void)for_stream; + for (const auto &header : default_headers_) { + if (!r.has_header(header.first)) { r.headers.insert(header); } + } + + if (!r.has_header("Host")) { + if (address_family_ == AF_UNIX) { + r.headers.emplace("Host", "localhost"); + } else { + r.headers.emplace( + "Host", detail::make_host_and_port_string(host_, port_, is_ssl())); + } + } + + if (!r.has_header("Accept")) { r.headers.emplace("Accept", "*/*"); } + + if (!r.content_receiver) { + if (!r.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; +#endif + r.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!r.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + r.set_header("User-Agent", agent); + } +#endif + } + + if (!r.body.empty()) { + if (!ct.empty() && !r.has_header("Content-Type")) { + r.headers.emplace("Content-Type", ct); + } + if (!r.has_header("Content-Length")) { + r.headers.emplace("Content-Length", std::to_string(r.body.size())); + } + } +} + +ClientImpl::StreamHandle +ClientImpl::open_stream(const std::string &method, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, + const std::string &content_type) { + StreamHandle handle; + handle.response = detail::make_unique(); + handle.error = Error::Success; + + auto query_path = params.empty() ? path : append_query_params(path, params); + handle.connection_ = detail::make_unique(); + + { + std::lock_guard guard(socket_mutex_); + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } + } +#endif + if (!is_alive) { + shutdown_ssl(socket_, false); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!ensure_socket_connection(socket_, handle.error)) { + handle.response.reset(); + return handle; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!scli.initialize_ssl(socket_, handle.error)) { + handle.response.reset(); + return handle; + } + } + } +#endif + } + + transfer_socket_ownership_to_handle(handle); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && handle.connection_->ssl) { + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, handle.connection_->ssl, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_); + } else { + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); + } +#else + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); +#endif + handle.stream_ = handle.socket_stream_.get(); + + Request req; + req.method = method; + req.path = query_path; + req.headers = headers; + req.body = body; + + prepare_default_headers(req, true, content_type); + + auto &strm = *handle.stream_; + if (detail::write_request_line(strm, req.method, req.path) < 0) { + handle.error = Error::Write; + handle.response.reset(); + return handle; + } + + if (!detail::check_and_write_headers(strm, req.headers, header_writer_, + handle.error)) { + handle.response.reset(); + return handle; + } + + if (!body.empty()) { + if (strm.write(body.data(), body.size()) < 0) { + handle.error = Error::Write; + handle.response.reset(); + return handle; + } + } + + if (!read_response_line(strm, req, *handle.response) || + !detail::read_headers(strm, handle.response->headers)) { + handle.error = Error::Read; + handle.response.reset(); + return handle; + } + + handle.body_reader_.stream = handle.stream_; + + auto content_length_str = handle.response->get_header_value("Content-Length"); + if (!content_length_str.empty()) { + handle.body_reader_.content_length = + static_cast(std::stoull(content_length_str)); + } + + auto transfer_encoding = + handle.response->get_header_value("Transfer-Encoding"); + handle.body_reader_.chunked = (transfer_encoding == "chunked"); + + auto content_encoding = handle.response->get_header_value("Content-Encoding"); + if (!content_encoding.empty()) { + handle.decompressor_ = detail::create_decompressor(content_encoding); + } + + return handle; +} + +ssize_t ClientImpl::StreamHandle::read(char *buf, size_t len) { + if (!is_valid() || !response) { return -1; } + + if (decompressor_) { return read_with_decompression(buf, len); } + auto n = detail::read_body_content(stream_, body_reader_, buf, len); + + if (n <= 0 && body_reader_.chunked && !trailers_parsed_ && stream_) { + trailers_parsed_ = true; + if (body_reader_.chunked_decoder) { + if (!body_reader_.chunked_decoder->parse_trailers_into( + response->trailers, response->headers)) { + return n; + } + } else { + detail::ChunkedDecoder dec(*stream_); + if (!dec.parse_trailers_into(response->trailers, response->headers)) { + return n; + } + } + } + + return n; +} + +ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, + size_t len) { + if (decompress_offset_ < decompress_buffer_.size()) { + auto available = decompress_buffer_.size() - decompress_offset_; + auto to_copy = (std::min)(len, available); + std::memcpy(buf, decompress_buffer_.data() + decompress_offset_, to_copy); + decompress_offset_ += to_copy; + return static_cast(to_copy); + } + + decompress_buffer_.clear(); + decompress_offset_ = 0; + + constexpr size_t kDecompressionBufferSize = 8192; + char compressed_buf[kDecompressionBufferSize]; + + while (true) { + auto n = detail::read_body_content(stream_, body_reader_, compressed_buf, + sizeof(compressed_buf)); + + if (n <= 0) { return n; } + + bool decompress_ok = + decompressor_->decompress(compressed_buf, static_cast(n), + [this](const char *data, size_t data_len) { + decompress_buffer_.append(data, data_len); + return true; + }); + + if (!decompress_ok) { + body_reader_.last_error = Error::Read; + return -1; + } + + if (!decompress_buffer_.empty()) { break; } + } + + auto to_copy = (std::min)(len, decompress_buffer_.size()); + std::memcpy(buf, decompress_buffer_.data(), to_copy); + decompress_offset_ = to_copy; + return static_cast(to_copy); +} + +void ClientImpl::StreamHandle::parse_trailers_if_needed() { + if (!response || !stream_ || !body_reader_.chunked || trailers_parsed_) { + return; + } + + trailers_parsed_ = true; + + const auto bufsiz = 128; + char line_buf[bufsiz]; + detail::stream_line_reader line_reader(*stream_, line_buf, bufsiz); + + if (!line_reader.getline()) { return; } + + if (!detail::parse_trailers(line_reader, response->trailers, + response->headers)) { + return; + } +} + +// Inline method implementations for `ChunkedDecoder`. +namespace detail { + +ChunkedDecoder::ChunkedDecoder(Stream &s) : strm(s) {} + +ssize_t ChunkedDecoder::read_payload(char *buf, size_t len, + size_t &out_chunk_offset, + size_t &out_chunk_total) { + if (finished) { return 0; } + + if (chunk_remaining == 0) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return -1; } + + char *endptr = nullptr; + unsigned long chunk_len = std::strtoul(lr.ptr(), &endptr, 16); + if (endptr == lr.ptr()) { return -1; } + if (chunk_len == ULONG_MAX) { return -1; } + + if (chunk_len == 0) { + chunk_remaining = 0; + finished = true; + out_chunk_offset = 0; + out_chunk_total = 0; + return 0; + } + + chunk_remaining = static_cast(chunk_len); + last_chunk_total = chunk_remaining; + last_chunk_offset = 0; + } + + auto to_read = (std::min)(chunk_remaining, len); + auto n = strm.read(buf, to_read); + if (n <= 0) { return -1; } + + auto offset_before = last_chunk_offset; + last_chunk_offset += static_cast(n); + chunk_remaining -= static_cast(n); + + out_chunk_offset = offset_before; + out_chunk_total = last_chunk_total; + + if (chunk_remaining == 0) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return -1; } + if (std::strcmp(lr.ptr(), "\r\n") != 0) { return -1; } + } + + return n; +} + +bool ChunkedDecoder::parse_trailers_into(Headers &dest, + const Headers &src_headers) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return false; } + return parse_trailers(lr, dest, src_headers); +} + +} // namespace detail + +void +ClientImpl::transfer_socket_ownership_to_handle(StreamHandle &handle) { + handle.connection_->sock = socket_.sock; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + handle.connection_->ssl = socket_.ssl; + socket_.ssl = nullptr; +#endif + socket_.sock = INVALID_SOCKET; +} + bool ClientImpl::handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { @@ -6227,9 +7184,11 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { auto req2 = req; - req2.path = "http://" + host_and_port_ + req.path; + req2.path = "http://" + + detail::make_host_and_port_string(host_, port_, false) + + req.path; ret = process_request(strm, req2, res, close_connection, error); - req = req2; + req = std::move(req2); req.path = req_save.path; } else { ret = process_request(strm, req, res, close_connection, error); @@ -6253,7 +7212,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, } if (300 < res.status && res.status < 400 && follow_location_) { - req = req_save; + req = std::move(req_save); ret = redirect(req, res, error); } @@ -6281,7 +7240,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, Response new_res; ret = send(new_req, new_res, error); - if (ret) { res = new_res; } + if (ret) { res = std::move(new_res); } } } } @@ -6514,42 +7473,11 @@ bool ClientImpl::write_request(Stream &strm, Request &req, } } - if (!req.has_header("Host")) { - // For Unix socket connections, use "localhost" as Host header (similar to - // curl behavior) - if (address_family_ == AF_UNIX) { - req.set_header("Host", "localhost"); - } else { - req.set_header("Host", host_and_port_); - } + std::string ct_for_defaults; + if (!req.has_header("Content-Type") && !req.body.empty()) { + ct_for_defaults = "text/plain"; } - - if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } - - if (!req.content_receiver) { - if (!req.has_header("Accept-Encoding")) { - std::string accept_encoding; -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - accept_encoding = "br"; -#endif -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (!accept_encoding.empty()) { accept_encoding += ", "; } - accept_encoding += "gzip, deflate"; -#endif -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - if (!accept_encoding.empty()) { accept_encoding += ", "; } - accept_encoding += "zstd"; -#endif - req.set_header("Accept-Encoding", accept_encoding); - } - -#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT - if (!req.has_header("User-Agent")) { - auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; - req.set_header("User-Agent", agent); - } -#endif - }; + prepare_default_headers(req, false, ct_for_defaults); if (req.body.empty()) { if (req.content_provider_) { @@ -6565,15 +7493,6 @@ bool ClientImpl::write_request(Stream &strm, Request &req, req.set_header("Content-Length", "0"); } } - } else { - if (!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); - } - - if (!req.has_header("Content-Length")) { - auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length); - } } if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { @@ -6620,18 +7539,41 @@ bool ClientImpl::write_request(Stream &strm, Request &req, query_part = ""; } - // Encode path and query + // Encode path part. If the original `req.path` already contained a + // query component, preserve its raw query string (including parameter + // order) instead of reparsing and reassembling it which may reorder + // parameters due to container ordering (e.g. `Params` uses + // `std::multimap`). When there is no query in `req.path`, fall back to + // building a query from `req.params` so existing callers that pass + // `Params` continue to work. auto path_with_query = path_encode_ ? detail::encode_path(path_part) : path_part; - detail::parse_query_text(query_part, req.params); - if (!req.params.empty()) { - path_with_query = append_query_params(path_with_query, req.params); + if (!query_part.empty()) { + // Normalize the query string (decode then re-encode) while preserving + // the original parameter order. + auto normalized = detail::normalize_query_string(query_part); + if (!normalized.empty()) { path_with_query += '?' + normalized; } + + // Still populate req.params for handlers/users who read them. + detail::parse_query_text(query_part, req.params); + } else { + // No query in path; parse any query_part (empty) and append params + // from `req.params` when present (preserves prior behavior for + // callers who provide Params separately). + detail::parse_query_text(query_part, req.params); + if (!req.params.empty()) { + path_with_query = append_query_params(path_with_query, req.params); + } } // Write request line and headers detail::write_request_line(bstrm, req.method, path_with_query); - header_writer_(bstrm, req.headers); + if (!detail::check_and_write_headers(bstrm, req.headers, header_writer_, + error)) { + output_error_log(error, &req); + return false; + } // Flush buffer auto &data = bstrm.get_buffer(); @@ -8096,7 +9038,9 @@ bool SSLSocketStream::wait_writable() const { ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret == 0) { error_ = Error::ConnectionClosed; } + return ret; } else if (wait_readable()) { auto ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret < 0) { @@ -8121,9 +9065,12 @@ ssize_t SSLSocketStream::read(char *ptr, size_t size) { } } assert(ret < 0); + } else if (ret == 0) { + error_ = Error::ConnectionClosed; } return ret; } else { + error_ = Error::Timeout; return -1; } } @@ -8499,7 +9446,8 @@ bool SSLClient::connect_with_proxy( start_time, [&](Stream &strm) { Request req2; req2.method = "CONNECT"; - req2.path = host_and_port_; + req2.path = + detail::make_host_and_port_string_always_port(host_, port_); if (max_timeout_msec_ > 0) { req2.start_time_ = std::chrono::steady_clock::now(); } @@ -8526,7 +9474,7 @@ bool SSLClient::connect_with_proxy( close_socket(socket); // Create a new socket for the authenticated CONNECT request - if (!create_and_connect_socket(socket, error)) { + if (!ensure_socket_connection(socket, error)) { success = false; output_error_log(error, nullptr); return false; @@ -8539,7 +9487,8 @@ bool SSLClient::connect_with_proxy( start_time, [&](Stream &strm) { Request req3; req3.method = "CONNECT"; - req3.path = host_and_port_; + req3.path = detail::make_host_and_port_string_always_port( + host_, port_); req3.headers.insert(detail::make_digest_authentication_header( req3, auth, 1, detail::random_string(10), proxy_digest_auth_username_, proxy_digest_auth_password_, @@ -9424,6 +10373,13 @@ Result Client::Options(const std::string &path, const Headers &headers) { return cli_->Options(path, headers); } +ClientImpl::StreamHandle +Client::open_stream(const std::string &method, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type) { + return cli_->open_stream(method, path, params, headers, body, content_type); +} + bool Client::send(Request &req, Response &res, Error &error) { return cli_->send(req, res, error); } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index c9bd9fd86b..43cdbc5832 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -1,15 +1,15 @@ // // httplib.h // -// Copyright (c) 2025 Yuji Hirose. All rights reserved. +// Copyright (c) 2026 Yuji Hirose. All rights reserved. // MIT License // #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.28.0" -#define CPPHTTPLIB_VERSION_NUM "0x001C00" +#define CPPHTTPLIB_VERSION "0.30.0" +#define CPPHTTPLIB_VERSION_NUM "0x001E00" /* * Platform compatibility check @@ -838,6 +838,50 @@ struct Response { std::string file_content_content_type_; }; +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + ConnectionClosed, + Timeout, + ResourceExhaustion, + TooManyFormDataFiles, + ExceedMaxPayloadSize, + ExceedUriMaxLength, + ExceedMaxSocketDescriptorCount, + InvalidRequestLine, + InvalidHTTPMethod, + InvalidHTTPVersion, + InvalidHeaders, + MultipartParsing, + OpenFile, + Listen, + GetSockName, + UnsupportedAddressFamily, + HTTPParsing, + InvalidRangeHeader, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + class Stream { public: virtual ~Stream() = default; @@ -856,6 +900,11 @@ public: ssize_t write(const char *ptr); ssize_t write(const std::string &s); + + Error get_error() const { return error_; } + +protected: + Error error_ = Error::Success; }; class TaskQueue { @@ -873,6 +922,7 @@ class ThreadPool final : public TaskQueue { public: explicit ThreadPool(size_t n, size_t mqr = 0) : shutdown_(false), max_queued_requests_(mqr) { + threads_.reserve(n); while (n) { threads_.emplace_back(worker(*this)); n--; @@ -961,27 +1011,21 @@ using ErrorLogger = std::function; using SocketOptions = std::function; -namespace detail { - -bool set_socket_opt_impl(socket_t sock, int level, int optname, - const void *optval, socklen_t optlen); -bool set_socket_opt(socket_t sock, int level, int optname, int opt); -bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, - time_t usec); - -} // namespace detail - void default_socket_options(socket_t sock); const char *status_message(int status); +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + std::string get_bearer_token_auth(const Request &req); namespace detail { class MatcherBase { public: - MatcherBase(std::string pattern) : pattern_(pattern) {} + MatcherBase(std::string pattern) : pattern_(std::move(pattern)) {} virtual ~MatcherBase() = default; const std::string &pattern() const { return pattern_; } @@ -1051,10 +1095,9 @@ private: std::regex regex_; }; -ssize_t write_headers(Stream &strm, const Headers &headers); +int close_socket(socket_t sock); -std::string make_host_and_port_string(const std::string &host, int port, - bool is_ssl); +ssize_t write_headers(Stream &strm, const Headers &headers); } // namespace detail @@ -1206,7 +1249,11 @@ private: bool listen_internal(); bool routing(Request &req, Response &res, Stream &strm); - bool handle_file_request(const Request &req, Response &res); + bool handle_file_request(Request &req, Response &res); + bool check_if_not_modified(const Request &req, Response &res, + const std::string &etag, time_t mtime) const; + bool check_if_range(Request &req, const std::string &etag, + time_t mtime) const; bool dispatch_request(Request &req, Response &res, const Handlers &handlers) const; bool dispatch_request_for_content_reader( @@ -1290,48 +1337,6 @@ private: detail::write_headers; }; -enum class Error { - Success = 0, - Unknown, - Connection, - BindIPAddress, - Read, - Write, - ExceedRedirectCount, - Canceled, - SSLConnection, - SSLLoadingCerts, - SSLServerVerification, - SSLServerHostnameVerification, - UnsupportedMultipartBoundaryChars, - Compression, - ConnectionTimeout, - ProxyConnection, - ResourceExhaustion, - TooManyFormDataFiles, - ExceedMaxPayloadSize, - ExceedUriMaxLength, - ExceedMaxSocketDescriptorCount, - InvalidRequestLine, - InvalidHTTPMethod, - InvalidHTTPVersion, - InvalidHeaders, - MultipartParsing, - OpenFile, - Listen, - GetSockName, - UnsupportedAddressFamily, - HTTPParsing, - InvalidRangeHeader, - - // For internal use only - SSLPeerCouldBeClosed_, -}; - -std::string to_string(Error error); - -std::ostream &operator<<(std::ostream &os, const Error &obj); - class Result { public: Result() = default; @@ -1390,6 +1395,87 @@ private: #endif }; +struct ClientConnection { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + + ClientConnection() = default; + + ~ClientConnection() { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (ssl) { + SSL_free(ssl); + ssl = nullptr; + } +#endif + if (sock != INVALID_SOCKET) { + detail::close_socket(sock); + sock = INVALID_SOCKET; + } + } + + ClientConnection(const ClientConnection &) = delete; + ClientConnection &operator=(const ClientConnection &) = delete; + + ClientConnection(ClientConnection &&other) noexcept + : sock(other.sock) +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + , + ssl(other.ssl) +#endif + { + other.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + other.ssl = nullptr; +#endif + } + + ClientConnection &operator=(ClientConnection &&other) noexcept { + if (this != &other) { + sock = other.sock; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ssl = other.ssl; +#endif + other.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + other.ssl = nullptr; +#endif + } + return *this; + } +}; + +namespace detail { + +struct ChunkedDecoder; + +struct BodyReader { + Stream *stream = nullptr; + size_t content_length = 0; + size_t bytes_read = 0; + bool chunked = false; + bool eof = false; + std::unique_ptr chunked_decoder; + Error last_error = Error::Success; + + ssize_t read(char *buf, size_t len); + bool has_error() const { return last_error != Error::Success; } +}; + +inline ssize_t read_body_content(Stream *stream, BodyReader &br, char *buf, + size_t len) { + (void)stream; + return br.read(buf, len); +} + +class decompressor; + +} // namespace detail + class ClientImpl { public: explicit ClientImpl(const std::string &host); @@ -1404,6 +1490,43 @@ public: virtual bool is_valid() const; + struct StreamHandle { + std::unique_ptr response; + Error error = Error::Success; + + StreamHandle() = default; + StreamHandle(const StreamHandle &) = delete; + StreamHandle &operator=(const StreamHandle &) = delete; + StreamHandle(StreamHandle &&) = default; + StreamHandle &operator=(StreamHandle &&) = default; + ~StreamHandle() = default; + + bool is_valid() const { + return response != nullptr && error == Error::Success; + } + + ssize_t read(char *buf, size_t len); + void parse_trailers_if_needed(); + Error get_read_error() const { return body_reader_.last_error; } + bool has_read_error() const { return body_reader_.has_error(); } + + bool trailers_parsed_ = false; + + private: + friend class ClientImpl; + + ssize_t read_with_decompression(char *buf, size_t len); + + std::unique_ptr connection_; + std::unique_ptr socket_stream_; + Stream *stream_ = nullptr; + detail::BodyReader body_reader_; + + std::unique_ptr decompressor_; + std::string decompress_buffer_; + size_t decompress_offset_ = 0; + }; + // clang-format off Result Get(const std::string &path, DownloadProgress progress = nullptr); Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr); @@ -1497,6 +1620,15 @@ public: Result Options(const std::string &path, const Headers &headers); // clang-format on + // Streaming API: Open a stream for reading response body incrementally + // Socket ownership is transferred to StreamHandle for true streaming + // Supports all HTTP methods (GET, POST, PUT, PATCH, DELETE, etc.) + StreamHandle open_stream(const std::string &method, const std::string &path, + const Params ¶ms = {}, + const Headers &headers = {}, + const std::string &body = {}, + const std::string &content_type = {}); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); @@ -1592,6 +1724,7 @@ protected: }; virtual bool create_and_connect_socket(Socket &socket, Error &error); + virtual bool ensure_socket_connection(Socket &socket, Error &error); // All of: // shutdown_ssl @@ -1618,7 +1751,6 @@ protected: // Socket endpoint information const std::string host_; const int port_; - const std::string host_and_port_; // Current open socket Socket socket_; @@ -1717,6 +1849,8 @@ private: Response &res) const; bool write_request(Stream &strm, Request &req, bool close_connection, Error &error); + void prepare_default_headers(Request &r, bool for_stream, + const std::string &ct); bool redirect(Request &req, Response &res, Error &error); bool create_redirect_client(const std::string &scheme, const std::string &host, int port, Request &req, @@ -1747,6 +1881,8 @@ private: std::chrono::time_point start_time, std::function callback); virtual bool is_ssl() const; + + void transfer_socket_ownership_to_handle(StreamHandle &handle); }; class Client { @@ -1865,6 +2001,16 @@ public: Result Options(const std::string &path, const Headers &headers); // clang-format on + // Streaming API: Open a stream for reading response body incrementally + // Socket ownership is transferred to StreamHandle for true streaming + // Supports all HTTP methods (GET, POST, PUT, PATCH, DELETE, etc.) + ClientImpl::StreamHandle open_stream(const std::string &method, + const std::string &path, + const Params ¶ms = {}, + const Headers &headers = {}, + const std::string &body = {}, + const std::string &content_type = {}); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); @@ -2027,6 +2173,7 @@ public: private: bool create_and_connect_socket(Socket &socket, Error &error) override; + bool ensure_socket_connection(Socket &socket, Error &error) override; void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); @@ -2163,82 +2310,6 @@ inline void default_socket_options(socket_t sock) { 1); } -inline const char *status_message(int status) { - switch (status) { - case StatusCode::Continue_100: return "Continue"; - case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; - case StatusCode::Processing_102: return "Processing"; - case StatusCode::EarlyHints_103: return "Early Hints"; - case StatusCode::OK_200: return "OK"; - case StatusCode::Created_201: return "Created"; - case StatusCode::Accepted_202: return "Accepted"; - case StatusCode::NonAuthoritativeInformation_203: - return "Non-Authoritative Information"; - case StatusCode::NoContent_204: return "No Content"; - case StatusCode::ResetContent_205: return "Reset Content"; - case StatusCode::PartialContent_206: return "Partial Content"; - case StatusCode::MultiStatus_207: return "Multi-Status"; - case StatusCode::AlreadyReported_208: return "Already Reported"; - case StatusCode::IMUsed_226: return "IM Used"; - case StatusCode::MultipleChoices_300: return "Multiple Choices"; - case StatusCode::MovedPermanently_301: return "Moved Permanently"; - case StatusCode::Found_302: return "Found"; - case StatusCode::SeeOther_303: return "See Other"; - case StatusCode::NotModified_304: return "Not Modified"; - case StatusCode::UseProxy_305: return "Use Proxy"; - case StatusCode::unused_306: return "unused"; - case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; - case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; - case StatusCode::BadRequest_400: return "Bad Request"; - case StatusCode::Unauthorized_401: return "Unauthorized"; - case StatusCode::PaymentRequired_402: return "Payment Required"; - case StatusCode::Forbidden_403: return "Forbidden"; - case StatusCode::NotFound_404: return "Not Found"; - case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; - case StatusCode::NotAcceptable_406: return "Not Acceptable"; - case StatusCode::ProxyAuthenticationRequired_407: - return "Proxy Authentication Required"; - case StatusCode::RequestTimeout_408: return "Request Timeout"; - case StatusCode::Conflict_409: return "Conflict"; - case StatusCode::Gone_410: return "Gone"; - case StatusCode::LengthRequired_411: return "Length Required"; - case StatusCode::PreconditionFailed_412: return "Precondition Failed"; - case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; - case StatusCode::UriTooLong_414: return "URI Too Long"; - case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; - case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; - case StatusCode::ExpectationFailed_417: return "Expectation Failed"; - case StatusCode::ImATeapot_418: return "I'm a teapot"; - case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; - case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; - case StatusCode::Locked_423: return "Locked"; - case StatusCode::FailedDependency_424: return "Failed Dependency"; - case StatusCode::TooEarly_425: return "Too Early"; - case StatusCode::UpgradeRequired_426: return "Upgrade Required"; - case StatusCode::PreconditionRequired_428: return "Precondition Required"; - case StatusCode::TooManyRequests_429: return "Too Many Requests"; - case StatusCode::RequestHeaderFieldsTooLarge_431: - return "Request Header Fields Too Large"; - case StatusCode::UnavailableForLegalReasons_451: - return "Unavailable For Legal Reasons"; - case StatusCode::NotImplemented_501: return "Not Implemented"; - case StatusCode::BadGateway_502: return "Bad Gateway"; - case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; - case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; - case StatusCode::HttpVersionNotSupported_505: - return "HTTP Version Not Supported"; - case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; - case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; - case StatusCode::LoopDetected_508: return "Loop Detected"; - case StatusCode::NotExtended_510: return "Not Extended"; - case StatusCode::NetworkAuthenticationRequired_511: - return "Network Authentication Required"; - - default: - case StatusCode::InternalServerError_500: return "Internal Server Error"; - } -} - inline std::string get_bearer_token_auth(const Request &req) { if (req.has_header("Authorization")) { constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); @@ -2272,55 +2343,6 @@ Server::set_idle_interval(const std::chrono::duration &duration) { return *this; } -inline std::string to_string(const Error error) { - switch (error) { - case Error::Success: return "Success (no error)"; - case Error::Unknown: return "Unknown"; - case Error::Connection: return "Could not establish connection"; - case Error::BindIPAddress: return "Failed to bind IP address"; - case Error::Read: return "Failed to read connection"; - case Error::Write: return "Failed to write connection"; - case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; - case Error::Canceled: return "Connection handling canceled"; - case Error::SSLConnection: return "SSL connection failed"; - case Error::SSLLoadingCerts: return "SSL certificate loading failed"; - case Error::SSLServerVerification: return "SSL server verification failed"; - case Error::SSLServerHostnameVerification: - return "SSL server hostname verification failed"; - case Error::UnsupportedMultipartBoundaryChars: - return "Unsupported HTTP multipart boundary characters"; - case Error::Compression: return "Compression failed"; - case Error::ConnectionTimeout: return "Connection timed out"; - case Error::ProxyConnection: return "Proxy connection failed"; - case Error::ResourceExhaustion: return "Resource exhaustion"; - case Error::TooManyFormDataFiles: return "Too many form data files"; - case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size"; - case Error::ExceedUriMaxLength: return "Exceeded maximum URI length"; - case Error::ExceedMaxSocketDescriptorCount: - return "Exceeded maximum socket descriptor count"; - case Error::InvalidRequestLine: return "Invalid request line"; - case Error::InvalidHTTPMethod: return "Invalid HTTP method"; - case Error::InvalidHTTPVersion: return "Invalid HTTP version"; - case Error::InvalidHeaders: return "Invalid headers"; - case Error::MultipartParsing: return "Multipart parsing failed"; - case Error::OpenFile: return "Failed to open file"; - case Error::Listen: return "Failed to listen on socket"; - case Error::GetSockName: return "Failed to get socket name"; - case Error::UnsupportedAddressFamily: return "Unsupported address family"; - case Error::HTTPParsing: return "HTTP parsing failed"; - case Error::InvalidRangeHeader: return "Invalid Range header"; - default: break; - } - - return "Invalid"; -} - -inline std::ostream &operator<<(std::ostream &os, const Error &obj) { - os << to_string(obj); - os << " (" << static_cast::type>(obj) << ')'; - return os; -} - inline size_t Result::get_request_header_value_u64(const std::string &key, size_t def, size_t id) const { @@ -2439,6 +2461,8 @@ struct FileStat { FileStat(const std::string &path); bool is_file() const; bool is_dir() const; + time_t mtime() const; + size_t size() const; private: #if defined(_WIN32) @@ -2449,6 +2473,9 @@ private: int ret_ = -1; }; +std::string make_host_and_port_string(const std::string &host, int port, + bool is_ssl); + std::string trim_copy(const std::string &s); void divide( @@ -2669,6 +2696,25 @@ private: std::string growable_buffer_; }; +bool parse_trailers(stream_line_reader &line_reader, Headers &dest, + const Headers &src_headers); + +struct ChunkedDecoder { + Stream &strm; + size_t chunk_remaining = 0; + bool finished = false; + char line_buf[64]; + size_t last_chunk_total = 0; + size_t last_chunk_offset = 0; + + explicit ChunkedDecoder(Stream &s); + + ssize_t read_payload(char *buf, size_t len, size_t &out_chunk_offset, + size_t &out_chunk_total); + + bool parse_trailers_into(Headers &dest, const Headers &src_headers); +}; + class mmap { public: mmap(const char *path); @@ -2696,59 +2742,669 @@ private: // NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 namespace fields { -inline bool is_token_char(char c) { - return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || - c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || - c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; -} - -inline bool is_token(const std::string &s) { - if (s.empty()) { return false; } - for (auto c : s) { - if (!is_token_char(c)) { return false; } - } - return true; -} - -inline bool is_field_name(const std::string &s) { return is_token(s); } - -inline bool is_vchar(char c) { return c >= 33 && c <= 126; } - -inline bool is_obs_text(char c) { return 128 <= static_cast(c); } - -inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } - -inline bool is_field_content(const std::string &s) { - if (s.empty()) { return true; } - - if (s.size() == 1) { - return is_field_vchar(s[0]); - } else if (s.size() == 2) { - return is_field_vchar(s[0]) && is_field_vchar(s[1]); - } else { - size_t i = 0; - - if (!is_field_vchar(s[i])) { return false; } - i++; - - while (i < s.size() - 1) { - auto c = s[i++]; - if (c == ' ' || c == '\t' || is_field_vchar(c)) { - } else { - return false; - } - } - - return is_field_vchar(s[i]); - } -} - -inline bool is_field_value(const std::string &s) { return is_field_content(s); } +bool is_token_char(char c); +bool is_token(const std::string &s); +bool is_field_name(const std::string &s); +bool is_vchar(char c); +bool is_obs_text(char c); +bool is_field_vchar(char c); +bool is_field_content(const std::string &s); +bool is_field_value(const std::string &s); } // namespace fields } // namespace detail +namespace stream { + +class Result { +public: + Result() : chunk_size_(8192) {} + + explicit Result(ClientImpl::StreamHandle &&handle, size_t chunk_size = 8192) + : handle_(std::move(handle)), chunk_size_(chunk_size) {} + + Result(Result &&other) noexcept + : handle_(std::move(other.handle_)), buffer_(std::move(other.buffer_)), + current_size_(other.current_size_), chunk_size_(other.chunk_size_), + finished_(other.finished_) { + other.current_size_ = 0; + other.finished_ = true; + } + + Result &operator=(Result &&other) noexcept { + if (this != &other) { + handle_ = std::move(other.handle_); + buffer_ = std::move(other.buffer_); + current_size_ = other.current_size_; + chunk_size_ = other.chunk_size_; + finished_ = other.finished_; + other.current_size_ = 0; + other.finished_ = true; + } + return *this; + } + + Result(const Result &) = delete; + Result &operator=(const Result &) = delete; + + // Check if the result is valid (connection succeeded and response received) + bool is_valid() const { return handle_.is_valid(); } + explicit operator bool() const { return is_valid(); } + + // Response status code + int status() const { + return handle_.response ? handle_.response->status : -1; + } + + // Response headers + const Headers &headers() const { + static const Headers empty_headers; + return handle_.response ? handle_.response->headers : empty_headers; + } + + std::string get_header_value(const std::string &key, + const char *def = "") const { + return handle_.response ? handle_.response->get_header_value(key, def) + : def; + } + + bool has_header(const std::string &key) const { + return handle_.response ? handle_.response->has_header(key) : false; + } + + // Error information + Error error() const { return handle_.error; } + Error read_error() const { return handle_.get_read_error(); } + bool has_read_error() const { return handle_.has_read_error(); } + + // Streaming iteration API + // Call next() to read the next chunk, then access data via data()/size() + // Returns true if data was read, false when stream is exhausted + bool next() { + if (!handle_.is_valid() || finished_) { return false; } + + if (buffer_.size() < chunk_size_) { buffer_.resize(chunk_size_); } + + ssize_t n = handle_.read(&buffer_[0], chunk_size_); + if (n > 0) { + current_size_ = static_cast(n); + return true; + } + + current_size_ = 0; + finished_ = true; + return false; + } + + // Pointer to current chunk data (valid after next() returns true) + const char *data() const { return buffer_.data(); } + + // Size of current chunk (valid after next() returns true) + size_t size() const { return current_size_; } + + // Convenience method: read all remaining data into a string + std::string read_all() { + std::string result; + while (next()) { + result.append(data(), size()); + } + return result; + } + +private: + ClientImpl::StreamHandle handle_; + std::string buffer_; + size_t current_size_ = 0; + size_t chunk_size_; + bool finished_ = false; +}; + +// GET +template +inline Result Get(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, {}, headers), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, params), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, params, headers), chunk_size}; +} + +// POST +template +inline Result Post(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("POST", path, params, headers, body, content_type), + chunk_size}; +} + +// PUT +template +inline Result Put(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("PUT", path, params, headers, body, content_type), + chunk_size}; +} + +// PATCH +template +inline Result Patch(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("PATCH", path, params, headers, body, content_type), + chunk_size}; +} + +// DELETE +template +inline Result Delete(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, {}, headers), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("DELETE", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params, headers), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("DELETE", path, params, headers, body, content_type), + chunk_size}; +} + +// HEAD +template +inline Result Head(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, {}, headers), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, params), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, params, headers), chunk_size}; +} + +// OPTIONS +template +inline Result Options(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, {}, headers), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, params), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, params, headers), chunk_size}; +} + +} // namespace stream + +namespace sse { + +struct SSEMessage { + std::string event; // Event type (default: "message") + std::string data; // Event payload + std::string id; // Event ID for Last-Event-ID header + + SSEMessage() : event("message") {} + + void clear() { + event = "message"; + data.clear(); + id.clear(); + } +}; + +class SSEClient { +public: + using MessageHandler = std::function; + using ErrorHandler = std::function; + using OpenHandler = std::function; + + SSEClient(Client &client, const std::string &path) + : client_(client), path_(path) {} + + SSEClient(Client &client, const std::string &path, const Headers &headers) + : client_(client), path_(path), headers_(headers) {} + + ~SSEClient() { stop(); } + + SSEClient(const SSEClient &) = delete; + SSEClient &operator=(const SSEClient &) = delete; + + // Event handlers + SSEClient &on_message(MessageHandler handler) { + on_message_ = std::move(handler); + return *this; + } + + SSEClient &on_event(const std::string &type, MessageHandler handler) { + event_handlers_[type] = std::move(handler); + return *this; + } + + SSEClient &on_open(OpenHandler handler) { + on_open_ = std::move(handler); + return *this; + } + + SSEClient &on_error(ErrorHandler handler) { + on_error_ = std::move(handler); + return *this; + } + + SSEClient &set_reconnect_interval(int ms) { + reconnect_interval_ms_ = ms; + return *this; + } + + SSEClient &set_max_reconnect_attempts(int n) { + max_reconnect_attempts_ = n; + return *this; + } + + // State accessors + bool is_connected() const { return connected_.load(); } + const std::string &last_event_id() const { return last_event_id_; } + + // Blocking start - runs event loop with auto-reconnect + void start() { + running_.store(true); + run_event_loop(); + } + + // Non-blocking start - runs in background thread + void start_async() { + running_.store(true); + async_thread_ = std::thread([this]() { run_event_loop(); }); + } + + // Stop the client (thread-safe) + void stop() { + running_.store(false); + client_.stop(); // Cancel any pending operations + if (async_thread_.joinable()) { async_thread_.join(); } + } + +private: + // Parse a single SSE field line + // Returns true if this line ends an event (blank line) + bool parse_sse_line(const std::string &line, SSEMessage &msg, int &retry_ms) { + // Blank line signals end of event + if (line.empty() || line == "\r") { return true; } + + // Lines starting with ':' are comments (ignored) + if (!line.empty() && line[0] == ':') { return false; } + + // Find the colon separator + auto colon_pos = line.find(':'); + if (colon_pos == std::string::npos) { + // Line with no colon is treated as field name with empty value + return false; + } + + auto field = line.substr(0, colon_pos); + std::string value; + + // Value starts after colon, skip optional single space + if (colon_pos + 1 < line.size()) { + auto value_start = colon_pos + 1; + if (line[value_start] == ' ') { value_start++; } + value = line.substr(value_start); + // Remove trailing \r if present + if (!value.empty() && value.back() == '\r') { value.pop_back(); } + } + + // Handle known fields + if (field == "event") { + msg.event = value; + } else if (field == "data") { + // Multiple data lines are concatenated with newlines + if (!msg.data.empty()) { msg.data += "\n"; } + msg.data += value; + } else if (field == "id") { + // Empty id is valid (clears the last event ID) + msg.id = value; + } else if (field == "retry") { + // Parse retry interval in milliseconds + try { + retry_ms = std::stoi(value); + } catch (...) { + // Invalid retry value, ignore + } + } + // Unknown fields are ignored per SSE spec + + return false; + } + + // Main event loop with auto-reconnect + void run_event_loop() { + auto reconnect_count = 0; + + while (running_.load()) { + // Build headers, including Last-Event-ID if we have one + auto request_headers = headers_; + if (!last_event_id_.empty()) { + request_headers.emplace("Last-Event-ID", last_event_id_); + } + + // Open streaming connection + auto result = stream::Get(client_, path_, request_headers); + + // Connection error handling + if (!result) { + connected_.store(false); + if (on_error_) { on_error_(result.error()); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + if (result.status() != 200) { + connected_.store(false); + // For certain errors, don't reconnect + if (result.status() == 204 || // No Content - server wants us to stop + result.status() == 404 || // Not Found + result.status() == 401 || // Unauthorized + result.status() == 403) { // Forbidden + if (on_error_) { on_error_(Error::Connection); } + break; + } + + if (on_error_) { on_error_(Error::Connection); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + // Connection successful + connected_.store(true); + reconnect_count = 0; + if (on_open_) { on_open_(); } + + // Event receiving loop + std::string buffer; + SSEMessage current_msg; + + while (running_.load() && result.next()) { + buffer.append(result.data(), result.size()); + + // Process complete lines in the buffer + size_t line_start = 0; + size_t newline_pos; + + while ((newline_pos = buffer.find('\n', line_start)) != + std::string::npos) { + auto line = buffer.substr(line_start, newline_pos - line_start); + line_start = newline_pos + 1; + + // Parse the line and check if event is complete + auto event_complete = + parse_sse_line(line, current_msg, reconnect_interval_ms_); + + if (event_complete && !current_msg.data.empty()) { + // Update last_event_id for reconnection + if (!current_msg.id.empty()) { last_event_id_ = current_msg.id; } + + // Dispatch event to appropriate handler + dispatch_event(current_msg); + + current_msg.clear(); + } + } + + // Keep unprocessed data in buffer + buffer.erase(0, line_start); + } + + // Connection ended + connected_.store(false); + + if (!running_.load()) { break; } + + // Check for read errors + if (result.has_read_error()) { + if (on_error_) { on_error_(result.read_error()); } + } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + } + + connected_.store(false); + } + + // Dispatch event to appropriate handler + void dispatch_event(const SSEMessage &msg) { + // Check for specific event type handler first + auto it = event_handlers_.find(msg.event); + if (it != event_handlers_.end()) { + it->second(msg); + return; + } + + // Fall back to generic message handler + if (on_message_) { on_message_(msg); } + } + + // Check if we should attempt to reconnect + bool should_reconnect(int count) const { + if (!running_.load()) { return false; } + if (max_reconnect_attempts_ == 0) { return true; } // unlimited + return count < max_reconnect_attempts_; + } + + // Wait for reconnect interval + void wait_for_reconnect() { + // Use small increments to check running_ flag frequently + auto waited = 0; + while (running_.load() && waited < reconnect_interval_ms_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + waited += 100; + } + } + + // Client and path + Client &client_; + std::string path_; + Headers headers_; + + // Callbacks + MessageHandler on_message_; + std::map event_handlers_; + OpenHandler on_open_; + ErrorHandler on_error_; + + // Configuration + int reconnect_interval_ms_ = 3000; + int max_reconnect_attempts_ = 0; // 0 = unlimited + + // State + std::atomic running_{false}; + std::atomic connected_{false}; + std::string last_event_id_; + + // Async support + std::thread async_thread_; +}; + +} // namespace sse + } // namespace httplib From cb14b069955f12992f6c47be98a061f23ea13cf5 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 8 Jan 2026 08:16:54 -0600 Subject: [PATCH 22/27] vulkan: optimize ssm_scan (#18630) * vulkan: optimize ssm_scan * fix warp vs subgroup naming --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 ++-- .../ggml-vulkan/vulkan-shaders/ssm_scan.comp | 108 ++++++++---------- 2 files changed, 59 insertions(+), 69 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4d3c085f67..7e17f4945d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -570,6 +570,7 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; + bool subgroup_basic; bool subgroup_arithmetic; bool subgroup_shuffle; bool subgroup_ballot; @@ -4301,8 +4302,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true); } else { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); @@ -4638,6 +4639,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic); device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); #ifdef __APPLE__ @@ -9870,8 +9873,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, std::array elements; - const int splitH = 16; - const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH); + const uint32_t d_state = src0->ne[0]; + uint32_t num_subgroups = d_state / ctx->device->subgroup_size; + const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups); const uint32_t num_workgroups_y = n_seq; elements = { num_workgroups_x, num_workgroups_y, 1 }; @@ -14777,11 +14781,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } - const uint32_t SPLIT_H = 16; + size_t shmem_size = d_state * sizeof(float); - size_t stateC_size = SPLIT_H * d_state * sizeof(float); + if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } - if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) { + if (!device->subgroup_basic) { return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp index 8f67be9799..c7416206db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable #if USE_SUBGROUP_ADD #extension GL_KHR_shader_subgroup_arithmetic : enable #endif @@ -9,7 +10,8 @@ layout(constant_id = 0) const uint D_STATE = 128; layout(constant_id = 1) const uint SUBGROUP_SIZE = 32; -layout(constant_id = 2) const uint SPLIT_H = 16; + +const uint32_t c_factor = D_STATE / SUBGROUP_SIZE; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -41,22 +43,28 @@ float softplus(float x) { } } -shared float stateC[SPLIT_H * D_STATE]; +#if !USE_SUBGROUP_ADD +shared float temp[D_STATE]; +#endif void main() { - const uint tid = gl_LocalInvocationID.x; - const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head; - const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4; - const uint seq_idx = gl_WorkGroupID.y; + const uint subgroup = gl_SubgroupID; + const uint lane = gl_SubgroupInvocationID; + const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane; + const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup; + + const uint head_idx = subgroup_idx / d_head; + const uint head_off = (subgroup_idx % d_head) * 4; + const uint seq_idx = gl_WorkGroupID.y; const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4; const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; - const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4; + const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4; const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4; const uint A_base_idx = (head_idx * nb31) / 4; const uint B_base_idx = (seq_idx * nb43 + group_off) / 4; const uint C_base_idx = (seq_idx * nb53 + group_off) / 4; - const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H; + const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx; const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; const uint stride_x = nb12 / 4; @@ -65,76 +73,52 @@ void main() { const uint stride_C = nb52 / 4; const uint stride_y = n_head * d_head; - float state[SPLIT_H]; - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - state[j] = s0[s0_base_idx + j * D_STATE + tid]; + float state[c_factor]; + + [[unroll]] for (uint j = 0; j < c_factor; j++) { + state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane]; } + float a = A[A_base_idx]; + for (uint i = 0; i < n_tok; i++) { - const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); + float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); - const float dA = exp(dt_soft_plus * A[A_base_idx]); - - const float B_val = B[B_base_idx + i * stride_B + tid]; - const float C_val = C[C_base_idx + i * stride_C + tid]; - - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus; + float state_sum = 0.0f; + const float dA = exp(dt_soft_plus * a); + const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus; + [[unroll]] for (uint j = 0; j < c_factor; j++) { + float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane]; + float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane]; state[j] = (state[j] * dA) + (B_val * x_dt); - - stateC[j * D_STATE + tid] = state[j] * C_val; + state_sum += state[j] * C_val; } +#if USE_SUBGROUP_ADD + state_sum = subgroupAdd(state_sum); +#else + temp[tid] = state_sum; barrier(); - [[unroll]] - for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) { - [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) { - const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w); - if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) { - stateC[k] += stateC[k + w]; - } + [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) { + if (lane < s) { + temp[tid] += temp[tid + s]; } barrier(); } - - [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) { - const uint idx = (tid % SUBGROUP_SIZE) + - D_STATE * (tid / SUBGROUP_SIZE) + - j * D_STATE * (D_STATE / SUBGROUP_SIZE); - const uint max_idx = SUBGROUP_SIZE - 1 + - D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) + - j * D_STATE * (D_STATE / SUBGROUP_SIZE); - - if (idx < SPLIT_H * D_STATE || - max_idx < SPLIT_H * D_STATE) { - float sc; -#if USE_SUBGROUP_ADD - sc = stateC[idx]; - sc = subgroupAdd(sc); -#else - [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { - if (idx + offset < SPLIT_H * D_STATE) { - stateC[idx] += stateC[idx + offset]; - } - barrier(); - } - if (tid % SUBGROUP_SIZE == 0) { - sc = stateC[idx]; - } + // get the value from lane 0 + state_sum = temp[subgroup * SUBGROUP_SIZE]; + barrier(); #endif - if (tid % SUBGROUP_SIZE == 0) { - const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); - d[y_base_idx + i * stride_y + k] = sc; - } - } + if (lane == 0) { + d[y_base_idx + i * stride_y] = state_sum; } - - barrier(); } - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - d[s_base_idx + j * D_STATE + tid] = state[j]; + // write back the state + [[unroll]] + for (int j = 0; j < c_factor; j++) { + d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j]; } } From 2524c2616458c7d6ee62fa4b4fa17e5091833544 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 8 Jan 2026 08:40:58 -0600 Subject: [PATCH 23/27] vulkan: fix push constant size for quantize_q8_1 (#18687) I added an assert to catch further mismatches, and it found several. Fix those, too. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 79 ++++++++++++++++------------ 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7e17f4945d..b1a51a4365 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1505,6 +1505,11 @@ template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); } +struct vk_quantize_q8_1_push_constants { + uint32_t ne; + uint32_t num_blocks; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -3341,12 +3346,12 @@ static void ggml_vk_load_shaders(vk_device& device) { GGML_ASSERT(device->subgroup_ballot); - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (device->coopmat_bf16_support) { - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); } #endif @@ -3454,9 +3459,9 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3498,9 +3503,9 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif } else { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3615,9 +3620,9 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3641,9 +3646,9 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3841,22 +3846,22 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } @@ -3944,9 +3949,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); if (device->subgroup_clustered && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); } else { - ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); } for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { @@ -4154,9 +4159,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ if (device->float_controls_rte_fp16) { \ @@ -6100,6 +6105,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); + GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants)); vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; @@ -6882,7 +6888,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t max_elements = std::min(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits::max()); const uint32_t elements = std::min(ne, static_cast(max_elements)); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ ne, num_blocks }, { elements, 1, 1 }); + const vk_quantize_q8_1_push_constants pc = { + ne, + num_blocks, + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 }); ggml_vk_sync_buffers(ctx, subctx); } From 15bff84bf56651d6f991f166a2bf0f362996f7f9 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 8 Jan 2026 08:23:39 -0800 Subject: [PATCH 24/27] ggml webgpu: initial flashattention implementation (#18610) * FlashAttention (#13) * Add inplace softmax * Move rms_norm to split row approach * Update debug for supports_op * clean up debug statements * neg f16xf32xip builds and runs, havent actually ran a model that uses neg kernel yet though * neg passes backend test * unary operators pass ggml tests * rms_norm double declaration bug atoned * abides by editor-config * removed vestigial files * fixed autoconfig * All operators (inlcluding xielu) working * removed unnecesarry checking if node->src[1] exists for unary operators * responded and dealt with PR comments * implemented REPL_Template support and removed bug in unary operators kernel * formatted embed wgsl and ggml-webgpu.cpp * Faster tensors (#8) Add fast matrix and matrix/vector multiplication. * Use map for shader replacements instead of pair of strings * Wasm (#9) * webgpu : fix build on emscripten * more debugging stuff * test-backend-ops: force single thread on wasm * fix single-thread case for init_tensor_uniform * use jspi * add pthread * test: remember to set n_thread for cpu backend * Add buffer label and enable dawn-specific toggles to turn off some checks * Intermediate state * Fast working f16/f32 vec4 * Working float fast mul mat * Clean up naming of mul_mat to match logical model, start work on q mul_mat * Setup for subgroup matrix mat mul * Basic working subgroup matrix * Working subgroup matrix tiling * Handle weirder sg matrix sizes (but still % sg matrix size) * Working start to gemv * working f16 accumulation with shared memory staging * Print out available subgroup matrix configurations * Vectorize dst stores for sg matrix shader * Gemv working scalar * Minor set_rows optimization (#4) * updated optimization, fixed errors * non vectorized version now dispatches one thread per element * Simplify * Change logic for set_rows pipelines --------- Co-authored-by: Neha Abbas Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Comment on dawn toggles * Working subgroup matrix code for (semi)generic sizes * Remove some comments * Cleanup code * Update dawn version and move to portable subgroup size * Try to fix new dawn release * Update subgroup size comment * Only check for subgroup matrix configs if they are supported * Add toggles for subgroup matrix/f16 support on nvidia+vulkan * Make row/col naming consistent * Refactor shared memory loading * Move sg matrix stores to correct file * Working q4_0 * Formatting * Work with emscripten builds * Fix test-backend-ops emscripten for f16/quantized types * Use emscripten memory64 to support get_memory * Add build flags and try ci --------- Co-authored-by: Xuan Son Nguyen * Remove extra whitespace * Move wasm single-thread logic out of test-backend-ops for cpu backend * Disable multiple threads for emscripten single-thread builds in ggml_graph_plan * Refactored pipelines and workgroup calculations (#10) * refactored pipelines * refactored workgroup calculation * removed commented out block of prior maps * Clean up ceiling division pattern --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Start work on flash attention * Shader structure set up (many bugs still) * debugging * Working first test * Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32 * Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling * Start work on integrating pre-wgsl * Separate structs/initial shader compilation library into separate files * Work on compilation choices for flashattention * Work on subgroup matrix/tile size portability * subgroup size agnostic online softmax * Cleanups, quantization types * more cleanup * fix wasm build * Refactor flashattention to increase parallelism, use direct loads for KV in somce cases * Checkpoint * formatting * Update to account for default kv cache padding * formatting shader * Add workflow for ggml-ci webgpu * Try passing absolute path to dawn in ggml-ci * Avoid error on device destruction, add todos for proper cleanup * Fix unused warning * Forgot one parameter unused * Move some flashattn computation to f32 for correctness --- .github/workflows/build.yml | 44 +- ci/run.sh | 15 +- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 169 ++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 288 ++++++- ggml/src/ggml-webgpu/pre_wgsl.hpp | 778 ++++++++++++++++++ .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 591 +++++++++++++ 6 files changed, 1838 insertions(+), 47 deletions(-) create mode 100644 ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp create mode 100644 ggml/src/ggml-webgpu/pre_wgsl.hpp create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 85601b3712..446a3750d7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -152,13 +152,13 @@ jobs: DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip" - echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release" + echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" curl -L -o artifact.zip \ - "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" + "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" mkdir dawn unzip artifact.zip - tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1 + tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build @@ -532,13 +532,13 @@ jobs: DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip" - echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release" + echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" curl -L -o artifact.zip \ - "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" + "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" mkdir dawn unzip artifact.zip - tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1 + tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build @@ -1704,6 +1704,34 @@ jobs: run: | GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + ggml-ci-mac-webgpu: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Dawn Dependency + id: dawn-depends + run: | + DAWN_VERSION="v2.0.0" + DAWN_OWNER="reeselevine" + DAWN_REPO="dawn" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release" + echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" + curl -L -o artifact.zip \ + "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" + mkdir dawn + unzip artifact.zip + tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1 + + - name: Test + id: ggml-ci + run: | + GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \ + bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + ggml-ci-mac-vulkan: runs-on: [self-hosted, macOS, ARM64] diff --git a/ci/run.sh b/ci/run.sh index 5c2d325a56..3deebd5dd3 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -105,7 +105,20 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then fi if [ ! -z ${GG_BUILD_WEBGPU} ]; then - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1 -DGGML_METAL=OFF -DGGML_BLAS=OFF" + + if [ ! -z "${GG_BUILD_WEBGPU_DAWN_PREFIX}" ]; then + if [ -z "${CMAKE_PREFIX_PATH}" ]; then + export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}" + else + export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}:${CMAKE_PREFIX_PATH}" + fi + fi + + # For some systems, Dawn_DIR needs to be set explicitly, e.g., the lib64 path + if [ ! -z "${GG_BUILD_WEBGPU_DAWN_DIR}" ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DDawn_DIR=${GG_BUILD_WEBGPU_DAWN_DIR}" + fi fi if [ ! -z ${GG_BUILD_MUSA} ]; then diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp new file mode 100644 index 0000000000..7fdb4c8c8d --- /dev/null +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -0,0 +1,169 @@ +#ifndef GGML_WEBGPU_SHADER_LIB_HPP +#define GGML_WEBGPU_SHADER_LIB_HPP + +#include "ggml.h" +#include "pre_wgsl.hpp" + +#include +#include + +#define GGML_WEBGPU_F16_SIZE_BYTES 2 +#define GGML_WEBGPU_F32_SIZE_BYTES 4 +#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u +#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u +// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing. +#define GGML_WEBGPU_KV_SEQ_PAD 256u + +struct ggml_webgpu_flash_attn_shader_lib_context { + ggml_type kv_type; + uint32_t head_dim_qk; + uint32_t head_dim_v; + bool kv_direct; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; + size_t wg_mem_limit_bytes; + uint32_t max_subgroup_size; +}; + +struct ggml_webgpu_flash_attn_shader_decisions { + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + ggml_webgpu_flash_attn_shader_decisions decisions; +}; + +// This is exposed because it's necessary in supports_op +inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, + uint32_t kv_tile, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); + size_t f16_elems = 0; + size_t f32_elems = 0; + f16_elems += q_tile * head_dim_qk; // q_shmem + if (!kv_direct) { + f16_elems += kv_tile * max_head_dim; // kv_shmem + } + f16_elems += q_tile * head_dim_v; // o_shmem + if (has_mask) { + f16_elems += q_tile * kv_tile; // mask_shmem + } + f16_elems += q_tile * kv_tile; // inter_shmem + f32_elems += q_tile; // row_max_shmem + f32_elems += q_tile; // exp_sum_shmem + return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; +} + +static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!context.kv_direct) { + bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v); + } + if (context.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; +} + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn"; + + switch (context.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(context.kv_type); + + if (context.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (context.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (context.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + + if (context.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(context.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.head_dim_v); + + // For now these are not part of the variant name + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + // Add chosen Q/KV tile sizes + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (context.kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + // Avoids having to use bounds-checks and decreasing performance for direct KV loads + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + // workgroup size + uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + result.decisions.q_tile = q_tile; + result.decisions.kv_tile = kv_tile; + result.decisions.wg_size = wg_size; + return result; +} + +#endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c7afdfb8e9..f64f94b96f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -7,7 +7,9 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" +#include "ggml-webgpu-shader-lib.hpp" #include "ggml-wgsl-shaders.hpp" +#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -30,7 +32,7 @@ #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl -# define WEBGPU_DEBUG_BUF_ELEMS 32 +# define WEBGPU_DEBUG_BUF_ELEMS 512 #else # define WEBGPU_LOG_DEBUG(msg) ((void) 0) #endif // GGML_WEBGPU_DEBUG @@ -251,6 +253,7 @@ struct webgpu_gpu_profile_buf_pool { struct webgpu_pipeline { wgpu::ComputePipeline pipeline; std::string name; + void * context = nullptr; }; struct webgpu_command { @@ -263,6 +266,46 @@ struct webgpu_command { #endif }; +struct flash_attn_pipeline_key { + int q_type; + int kv_type; + int dst_type; + uint32_t head_dim_qk; + uint32_t head_dim_v; + bool kv_direct; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + + bool operator==(const flash_attn_pipeline_key & other) const { + return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && + head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && + has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap; + } +}; + +// Same hash combine function as in boost +template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { + seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +struct flash_attn_pipeline_key_hash { + size_t operator()(const flash_attn_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_type); + ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + return seed; + } +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { wgpu::Instance instance; @@ -271,12 +314,12 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; - uint32_t subgroup_size; + uint32_t max_subgroup_size; -#ifndef __EMSCRIPTEN__ - bool supports_subgroup_matrix = false; - wgpu::SubgroupMatrixConfig subgroup_matrix_config; -#endif + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; std::recursive_mutex mutex; std::atomic_uint inflight_threads = 0; @@ -284,20 +327,24 @@ struct webgpu_context_struct { webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; + pre_wgsl::Preprocessor p; + std::map memset_pipelines; // variant or type index std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized std::map>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - std::map> set_rows_pipelines; // dst_type, vectorized - std::map> get_rows_pipelines; // src_type, vectorized + std::unordered_map flash_attn_pipelines; - std::map> cpy_pipelines; // src_type, dst_type - std::map> add_pipelines; // type, inplace - std::map> sub_pipelines; // type, inplace - std::map> mul_pipelines; // type, inplace - std::map> div_pipelines; // type, inplace + std::map> set_rows_pipelines; // dst_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized + + std::map> cpy_pipelines; // src_type, dst_type + std::map> add_pipelines; // type, inplace + std::map> sub_pipelines; // type, inplace + std::map> mul_pipelines; // type, inplace + std::map> div_pipelines; // type, inplace std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace @@ -361,8 +408,6 @@ struct ggml_backend_webgpu_buffer_context { label(std::move(lbl)) {} }; -/* End struct definitions */ - /* WebGPU object initializations */ // Process a WGSL shader string, replacing tokens of the form {{KEY}} with @@ -484,14 +529,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); - ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); - const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange(); - std::cout << "debug data:"; - for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) { - std::cout << " " << i << ": " << debug_data[i]; - } - std::cout << "\n"; + const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); + std::cout << "debug[0]: " << debug_data[0] << "\n"; ctx->debug_host_buf.Unmap(); } #endif @@ -673,6 +713,7 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { return ctx->name.c_str(); } +// TODO: implement proper cleanup static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); @@ -730,12 +771,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { return ctx->buffer; } -static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) { +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); } -static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) { +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); } @@ -964,12 +1005,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, #ifndef __EMSCRIPTEN__ if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; - wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); + uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); + uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { #endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; @@ -986,6 +1025,146 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + float scale = *(float *) dst->op_params; + float max_bias; + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + float logit_softcap; + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int has_mask = (mask != nullptr); + const int has_sinks = (sinks != nullptr); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), + has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, + has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) Q->ne[2], // number of heads + (uint32_t) Q->ne[1], // sequence length (Q) + (uint32_t) K->ne[1], // sequence length (K/V) + (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 + (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 + (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 + (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 + (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 + (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 + (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 + (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 + (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 + has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 + (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) + *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) + *(uint32_t *) &max_bias, + *(uint32_t *) &logit_softcap, + *(uint32_t *) &n_head_log2, + *(uint32_t *) &m0, + *(uint32_t *) &m1 + + }; + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(Q), + .offset = ggml_webgpu_tensor_align_offset(ctx, Q), + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(K), + .offset = ggml_webgpu_tensor_align_offset(ctx, K), + .size = ggml_webgpu_tensor_binding_size(ctx, K) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(V), + .offset = ggml_webgpu_tensor_align_offset(ctx, V), + .size = ggml_webgpu_tensor_binding_size(ctx, V) } + }; + uint32_t binding_index = 3; + if (has_mask) { + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + } + if (has_sinks) { + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(sinks), + .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), + .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + } + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + bool kv_direct = + (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + flash_attn_pipeline_key key = { + .q_type = Q->type, + .kv_type = K->type, + .dst_type = dst->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, + .has_mask = static_cast(has_mask), + .has_sinks = static_cast(has_sinks), + .uses_logit_softcap = logit_softcap != 0.0f, + }; + + webgpu_pipeline pipeline; + ggml_webgpu_flash_attn_shader_decisions decisions = {}; + + auto it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = it->second; + decisions = *static_cast(pipeline.context); + } else { + std::lock_guard lock(ctx->mutex); + it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = it->second; + decisions = *static_cast(pipeline.context); + } else { + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, + .has_mask = static_cast(has_mask), + .has_sinks = static_cast(has_sinks), + .uses_logit_softcap = logit_softcap != 0.0f, + .sg_mat_m = ctx->sg_mat_m, + .sg_mat_n = ctx->sg_mat_n, + .sg_mat_k = ctx->sg_mat_k, + .wg_mem_limit_bytes = + ctx->limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->max_subgroup_size }; + + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions); + ctx->flash_attn_pipelines.emplace(key, pipeline); + decisions = processed.decisions; + } + } + + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); ggml_unary_op unary_op = ggml_get_unary_op(dst); @@ -1397,6 +1576,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, src0, src1, node); + case GGML_OP_FLASH_ATTN_EXT: + return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); @@ -1466,6 +1647,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); futures.push_back(new_futures); } + ggml_backend_webgpu_wait(ctx, futures); ctx->inflight_threads--; WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); @@ -1808,15 +1990,15 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { #ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size); sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k); proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); proc_mul_mat_f32_f32_vec = @@ -2328,6 +2510,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants); } +// TODO: move most initialization logic here static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); @@ -2489,6 +2672,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } break; } + case GGML_OP_FLASH_ATTN_EXT: + { + if (!webgpu_ctx->supports_subgroup_matrix) { + break; + } + // Head dimensions must fit in workgroup memory with minimum tile sizes + size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], + has_mask, kv_direct); + if (min_bytes > limit_bytes) { + break; + } + + supports_op = src0->type == GGML_TYPE_F32 && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || + src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && + src2->type == src1->type && op->type == GGML_TYPE_F32; + break; + } case GGML_OP_RMS_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; @@ -2606,6 +2812,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { } // TODO: Does this need to be thread safe? Is it only called once? +// TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); @@ -2665,7 +2872,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && config.componentType == wgpu::SubgroupMatrixComponentType::F16 && config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->subgroup_matrix_config = config; + ctx->sg_mat_m = config.M; + ctx->sg_mat_n = config.N; + ctx->sg_mat_k = config.K; valid_subgroup_matrix_config = true; break; } @@ -2676,7 +2885,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t #endif // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->subgroup_size = info.subgroupMaxSize; + ctx->max_subgroup_size = info.subgroupMaxSize; // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16 }; @@ -2701,8 +2910,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), - std::string(message).c_str()); + GGML_UNUSED(reason); + GGML_UNUSED(message); + //TODO: uncomment once proper free logic is in place + //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + //std::string(message).c_str()); }); dev_desc.SetUncapturedErrorCallback( [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp new file mode 100644 index 0000000000..4d4359463c --- /dev/null +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -0,0 +1,778 @@ +#ifndef PRE_WGSL_HPP +#define PRE_WGSL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace pre_wgsl { + +//============================================================== +// Options +//============================================================== +struct Options { + std::string include_path = "."; + std::vector macros; +}; + +//============================================================== +// Utility: trim +//============================================================== +static std::string trim(const std::string & s) { + size_t a = 0; + while (a < s.size() && std::isspace((unsigned char) s[a])) { + a++; + } + size_t b = s.size(); + while (b > a && std::isspace((unsigned char) s[b - 1])) { + b--; + } + return s.substr(a, b - a); +} + +static std::string trim_value(std::istream & is) { + std::string str; + std::getline(is, str); + return trim(str); +} + +static bool isIdentChar(char c) { + return std::isalnum(static_cast(c)) || c == '_'; +} + +static std::string expandMacrosRecursiveInternal(const std::string & line, + const std::unordered_map & macros, + std::unordered_set & visiting); + +static std::string expandMacroValue(const std::string & name, + const std::unordered_map & macros, + std::unordered_set & visiting) { + if (visiting.count(name)) { + throw std::runtime_error("Recursive macro: " + name); + } + visiting.insert(name); + + auto it = macros.find(name); + if (it == macros.end()) { + visiting.erase(name); + return name; + } + + const std::string & value = it->second; + if (value.empty()) { + visiting.erase(name); + return ""; + } + + std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting); + visiting.erase(name); + return expanded; +} + +static std::string expandMacrosRecursiveInternal(const std::string & line, + const std::unordered_map & macros, + std::unordered_set & visiting) { + std::string result; + result.reserve(line.size()); + + size_t i = 0; + while (i < line.size()) { + if (isIdentChar(line[i])) { + size_t start = i; + while (i < line.size() && isIdentChar(line[i])) { + i++; + } + std::string token = line.substr(start, i - start); + + auto it = macros.find(token); + if (it != macros.end()) { + result += expandMacroValue(token, macros, visiting); + } else { + result += token; + } + } else { + result += line[i]; + i++; + } + } + + return result; +} + +static std::string expandMacrosRecursive(const std::string & line, + const std::unordered_map & macros) { + std::unordered_set visiting; + return expandMacrosRecursiveInternal(line, macros, visiting); +} + +//============================================================== +// Tokenizer for expressions in #if/#elif +//============================================================== +class ExprLexer { + public: + enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN }; + + struct Tok { + Kind kind; + std::string text; + }; + + explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} + + Tok next() { + skipWS(); + if (pos >= src.size()) { + return { END, "" }; + } + + char c = src[pos]; + + // number + if (std::isdigit((unsigned char) c)) { + size_t start = pos; + while (pos < src.size() && std::isdigit((unsigned char) src[pos])) { + pos++; + } + return { NUMBER, std::string(src.substr(start, pos - start)) }; + } + + // identifier + if (std::isalpha((unsigned char) c) || c == '_') { + size_t start = pos; + while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) { + pos++; + } + return { IDENT, std::string(src.substr(start, pos - start)) }; + } + + if (c == '(') { + pos++; + return { LPAREN, "(" }; + } + if (c == ')') { + pos++; + return { RPAREN, ")" }; + } + + // multi-char operators + static const char * two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" }; + for (auto op : two_ops) { + if (src.substr(pos, 2) == op) { + pos += 2; + return { OP, std::string(op) }; + } + } + + // single-char operators + if (std::string("+-*/%<>!").find(c) != std::string::npos) { + pos++; + return { OP, std::string(1, c) }; + } + + // unexpected + pos++; + return { END, "" }; + } + + private: + std::string_view src; + size_t pos; + + void skipWS() { + while (pos < src.size() && std::isspace((unsigned char) src[pos])) { + pos++; + } + } +}; + +//============================================================== +// Expression Parser (recursive descent) +//============================================================== +class ExprParser { + public: + ExprParser(std::string_view expr, + const std::unordered_map & macros, + std::unordered_set & visiting) : + lex(expr), + macros(macros), + visiting(visiting) { + advance(); + } + + int parse() { return parseLogicalOr(); } + + private: + ExprLexer lex; + ExprLexer::Tok tok; + const std::unordered_map & macros; + std::unordered_set & visiting; + + void advance() { tok = lex.next(); } + + bool acceptOp(const std::string & s) { + if (tok.kind == ExprLexer::OP && tok.text == s) { + advance(); + return true; + } + return false; + } + + bool acceptKind(ExprLexer::Kind k) { + if (tok.kind == k) { + advance(); + return true; + } + return false; + } + + int parseLogicalOr() { + int v = parseLogicalAnd(); + while (acceptOp("||")) { + int rhs = parseLogicalAnd(); + v = (v || rhs); + } + return v; + } + + int parseLogicalAnd() { + int v = parseEquality(); + while (acceptOp("&&")) { + int rhs = parseEquality(); + v = (v && rhs); + } + return v; + } + + int parseEquality() { + int v = parseRelational(); + for (;;) { + if (acceptOp("==")) { + int rhs = parseRelational(); + v = (v == rhs); + } else if (acceptOp("!=")) { + int rhs = parseRelational(); + v = (v != rhs); + } else { + break; + } + } + return v; + } + + int parseRelational() { + int v = parseShift(); + for (;;) { + if (acceptOp("<")) { + int rhs = parseShift(); + v = (v < rhs); + } else if (acceptOp(">")) { + int rhs = parseShift(); + v = (v > rhs); + } else if (acceptOp("<=")) { + int rhs = parseShift(); + v = (v <= rhs); + } else if (acceptOp(">=")) { + int rhs = parseShift(); + v = (v >= rhs); + } else { + break; + } + } + return v; + } + + int parseShift() { + int v = parseAdd(); + for (;;) { + if (acceptOp("<<")) { + int rhs = parseAdd(); + v = (v << rhs); + } else if (acceptOp(">>")) { + int rhs = parseAdd(); + v = (v >> rhs); + } else { + break; + } + } + return v; + } + + int parseAdd() { + int v = parseMult(); + for (;;) { + if (acceptOp("+")) { + int rhs = parseMult(); + v = (v + rhs); + } else if (acceptOp("-")) { + int rhs = parseMult(); + v = (v - rhs); + } else { + break; + } + } + return v; + } + + int parseMult() { + int v = parseUnary(); + for (;;) { + if (acceptOp("*")) { + int rhs = parseUnary(); + v = (v * rhs); + } else if (acceptOp("/")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v / rhs); + } else if (acceptOp("%")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v % rhs); + } else { + break; + } + } + return v; + } + + int parseUnary() { + if (acceptOp("!")) { + return !parseUnary(); + } + if (acceptOp("-")) { + return -parseUnary(); + } + if (acceptOp("+")) { + return +parseUnary(); + } + return parsePrimary(); + } + + int parsePrimary() { + // '(' expr ')' + if (acceptKind(ExprLexer::LPAREN)) { + int v = parse(); + if (!acceptKind(ExprLexer::RPAREN)) { + throw std::runtime_error("missing ')'"); + } + return v; + } + + // number + if (tok.kind == ExprLexer::NUMBER) { + int v = std::stoi(tok.text); + advance(); + return v; + } + + // defined(identifier) + if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { + advance(); + if (acceptKind(ExprLexer::LPAREN)) { + if (tok.kind != ExprLexer::IDENT) { + throw std::runtime_error("expected identifier in defined()"); + } + std::string name = tok.text; + advance(); + if (!acceptKind(ExprLexer::RPAREN)) { + throw std::runtime_error("missing ) in defined()"); + } + return macros.count(name) ? 1 : 0; + } else { + // defined NAME + if (tok.kind != ExprLexer::IDENT) { + throw std::runtime_error("expected identifier in defined NAME"); + } + std::string name = tok.text; + advance(); + return macros.count(name) ? 1 : 0; + } + } + + // identifier -> treat as integer, if defined use its value else 0 + if (tok.kind == ExprLexer::IDENT) { + std::string name = tok.text; + advance(); + auto it = macros.find(name); + if (it == macros.end()) { + return 0; + } + if (it->second.empty()) { + return 1; + } + return evalMacroExpression(name, it->second); + } + + // unexpected + return 0; + } + + int evalMacroExpression(const std::string & name, const std::string & value) { + if (visiting.count(name)) { + throw std::runtime_error("Recursive macro: " + name); + } + + visiting.insert(name); + ExprParser ep(value, macros, visiting); + int v = ep.parse(); + visiting.erase(name); + return v; + } +}; + +//============================================================== +// Preprocessor +//============================================================== +class Preprocessor { + public: + explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) { + // Treat empty include path as current directory + if (opts_.include_path.empty()) { + opts_.include_path = "."; + } + parseMacroDefinitions(opts_.macros); + } + + std::string preprocess_file(const std::string & filename, const std::vector & additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All); + return result; + } + + std::string preprocess(const std::string & contents, const std::vector & additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All); + return result; + } + + std::string preprocess_includes_file(const std::string & filename) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly); + return result; + } + + std::string preprocess_includes(const std::string & contents) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly); + return result; + } + + private: + Options opts_; + std::unordered_map global_macros; + + enum class DirectiveMode { All, IncludesOnly }; + + struct Cond { + bool parent_active; + bool active; + bool taken; + }; + + //---------------------------------------------------------- + // Parse macro definitions into global_macros + //---------------------------------------------------------- + void parseMacroDefinitions(const std::vector & macro_defs) { + for (const auto & def : macro_defs) { + size_t eq_pos = def.find('='); + if (eq_pos != std::string::npos) { + // Format: NAME=VALUE + std::string name = trim(def.substr(0, eq_pos)); + std::string value = trim(def.substr(eq_pos + 1)); + global_macros[name] = value; + } else { + // Format: NAME + std::string name = trim(def); + global_macros[name] = ""; + } + } + } + + //---------------------------------------------------------- + // Build combined macro map and predefined set for a preprocessing operation + //---------------------------------------------------------- + void buildMacros(const std::vector & additional_macros, + std::unordered_map & macros, + std::unordered_set & predefined) { + macros = global_macros; + predefined.clear(); + + for (const auto & [name, value] : global_macros) { + predefined.insert(name); + } + + for (const auto & def : additional_macros) { + size_t eq_pos = def.find('='); + std::string name, value; + if (eq_pos != std::string::npos) { + name = trim(def.substr(0, eq_pos)); + value = trim(def.substr(eq_pos + 1)); + } else { + name = trim(def); + value = ""; + } + + // Add to macros map (will override global if same name) + macros[name] = value; + predefined.insert(name); + } + } + + //---------------------------------------------------------- + // Helpers + //---------------------------------------------------------- + std::string loadFile(const std::string & fname) { + std::ifstream f(fname); + if (!f.is_open()) { + throw std::runtime_error("Could not open file: " + fname); + } + std::stringstream ss; + ss << f.rdbuf(); + return ss.str(); + } + + bool condActive(const std::vector & cond) const { + if (cond.empty()) { + return true; + } + return cond.back().active; + } + + //---------------------------------------------------------- + // Process a file + //---------------------------------------------------------- + std::string processFile(const std::string & name, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + if (include_stack.count(name)) { + throw std::runtime_error("Recursive include: " + name); + } + + include_stack.insert(name); + std::string shader_code = loadFile(name); + std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode); + include_stack.erase(name); + return out; + } + + std::string processIncludeFile(const std::string & fname, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + std::string full_path = opts_.include_path + "/" + fname; + return processFile(full_path, macros, predefined_macros, include_stack, mode); + } + + //---------------------------------------------------------- + // Process text + //---------------------------------------------------------- + std::string processString(const std::string & shader_code, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + std::vector cond; // Conditional stack for this shader + std::stringstream out; + std::istringstream in(shader_code); + std::string line; + + while (std::getline(in, line)) { + std::string t = trim(line); + + if (!t.empty() && t[0] == '#') { + bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode); + if (mode == DirectiveMode::IncludesOnly && !handled) { + out << line << "\n"; + } + } else { + if (mode == DirectiveMode::IncludesOnly) { + out << line << "\n"; + } else if (condActive(cond)) { + // Expand macros in the line before outputting + std::string expanded = expandMacrosRecursive(line, macros); + out << expanded << "\n"; + } + } + } + + if (mode == DirectiveMode::All && !cond.empty()) { + throw std::runtime_error("Unclosed #if directive"); + } + + return out.str(); + } + + //---------------------------------------------------------- + // Directive handler + //---------------------------------------------------------- + bool handleDirective(const std::string & t, + std::stringstream & out, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::vector & cond, + std::unordered_set & include_stack, + DirectiveMode mode) { + // split into tokens + std::string body = t.substr(1); + std::istringstream iss(body); + std::string cmd; + iss >> cmd; + + if (cmd == "include") { + if (mode == DirectiveMode::All && !condActive(cond)) { + return true; + } + std::string file; + iss >> file; + if (file.size() >= 2 && file.front() == '"' && file.back() == '"') { + file = file.substr(1, file.size() - 2); + } + out << processIncludeFile(file, macros, predefined_macros, include_stack, mode); + return true; + } + + if (mode == DirectiveMode::IncludesOnly) { + return false; + } + + if (cmd == "define") { + if (!condActive(cond)) { + return true; + } + std::string name; + iss >> name; + // Don't override predefined macros from options + if (predefined_macros.count(name)) { + return true; + } + std::string value = trim_value(iss); + macros[name] = value; + return true; + } + + if (cmd == "undef") { + if (!condActive(cond)) { + return true; + } + std::string name; + iss >> name; + // Don't undef predefined macros from options + if (predefined_macros.count(name)) { + return true; + } + macros.erase(name); + return true; + } + + if (cmd == "ifdef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = macros.count(name); + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "ifndef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = !macros.count(name); + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "if") { + std::string expr = trim_value(iss); + bool p = condActive(cond); + bool v = false; + if (p) { + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + v = ep.parse() != 0; + } + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "elif") { + std::string expr = trim_value(iss); + + if (cond.empty()) { + throw std::runtime_error("#elif without #if"); + } + + Cond & c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + + if (c.taken) { + c.active = false; + return true; + } + + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + bool v = ep.parse() != 0; + c.active = v; + if (v) { + c.taken = true; + } + return true; + } + + if (cmd == "else") { + if (cond.empty()) { + throw std::runtime_error("#else without #if"); + } + + Cond & c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + if (c.taken) { + c.active = false; + } else { + c.active = true; + c.taken = true; + } + return true; + } + + if (cmd == "endif") { + if (cond.empty()) { + throw std::runtime_error("#endif without #if"); + } + cond.pop_back(); + return true; + } + + // Unknown directive + throw std::runtime_error("Unknown directive: #" + cmd); + } +}; + +} // namespace pre_wgsl + +#endif // PRE_WGSL_HPP diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl new file mode 100644 index 0000000000..de7c132a62 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -0,0 +1,591 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +// Default values +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + +// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN +// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. +#define SG_MAT_M 8 +#define SG_MAT_N 8 +#define SG_MAT_K 8 + +// Each workgroup processes one subgroup matrix of Q rows +#define Q_TILE SG_MAT_M +#define KV_TILE 16 +#define WG_SIZE 64 + +// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. +#define KV_BLOCKS (KV_TILE / SG_MAT_N) + +// Quantization constants/helpers +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +// number of quantized elements processed per thread +#if defined(KV_Q4_0) +#define NQ 16 +// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +// Ok not to put these in a define block, compiler will remove if unused +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) var Q: array; +@group(0) @binding(1) var K: array; +@group(0) @binding(2) var V: array; + +#if defined(MASK) && defined(SINKS) +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#elif defined(MASK) +@group(0) @binding(3) var mask: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#elif defined(SINKS) +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#endif + +@group(0) @binding(DST_BINDING) var dst: array; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +// Just a very small float value. +const FLOAT_MIN: f32 = -1.0e9; + +// The number of Q rows processed per workgroup +var q_shmem: array; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var kv_shmem: array; +#endif + +var o_shmem: array; // output shmem + +#ifdef MASK +// storage for mask values +var mask_shmem: array; +#endif + +// storage for output of Q*K^T scores for online softmax (S matrix from paper) +// also storage for diagonal matrix during online softmax (P matrix from paper) +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +var row_max_shmem: array; +var exp_sum_shmem: array; + +fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); + let mask_term = slope * mask_val; + v += mask_term; +#endif + return v; +} + + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + + // initialize row max for online softmax + for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { + row_max_shmem[i] = FLOAT_MIN; + exp_sum_shmem[i] = 0.0; + } + + for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + // batch index + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // starting Q row for this workgroup + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; + + let head = f32(head_idx); + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); + + // load q tile into shared memory + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col], + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + } + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + // clear inter_shmem to ensure zero-initialized accumulators + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } + + // load k tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + kv_shmem[elem_idx] = f16(select( + 0.0, + K[global_k_row_offset + k_col], + global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + } +#endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + // TODO: this loop seems to be the current largest bottleneck + for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { + let inter_offset = kv_block * SG_MAT_N; + var acc: subgroup_matrix_result = subgroupMatrixLoad< + subgroup_matrix_result>(&inter_shmem, inter_offset, false, KV_TILE); +#ifdef KV_DIRECT + let k_block_row = kv_tile + kv_block * SG_MAT_N; + let k_global_offset = k_head_offset + k_block_row * params.stride_k1; +#else + let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK; +#endif + for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) { + // load q submatrix from shared memory + var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &q_shmem, + head_dim_block, + false, + HEAD_DIM_QK + ); + + // load k submatrix from device or shared memory +#ifdef KV_DIRECT + var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &K, + k_global_offset + head_dim_block, + true, + params.stride_k1 + ); +#else + var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &kv_shmem, + k_block_offset + head_dim_block, + true, + HEAD_DIM_QK + ); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); + } + + // store acc to shared memory for softmax (S matrix from paper) + subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + } + +#ifdef MASK + // load mask tile into shared memory for this KV block + // TODO: optimize and skip if mask is -INF for the entire tile + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + let mask_row = elem_idx / KV_TILE; + let mask_col = elem_idx % KV_TILE; + let global_q_row = q_row_start + mask_row; + let global_k_col = kv_tile + mask_col; + let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + } +#endif + + workgroupBarrier(); + + // online softmax + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + // initialize running max for this row + var prev_max = row_max_shmem[q_tile_row]; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); + } + } + + let cur_exp = exp(prev_max - final_max); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); + } + } + + // load v tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + kv_shmem[elem_idx] = f16(select( + 0.0, + V[global_v_row_offset + v_col], + global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); + } +#endif + + workgroupBarrier(); + + // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + for (var head_dim_block = subgroup_id * SG_MAT_N; + head_dim_block < HEAD_DIM_V; + head_dim_block += num_subgroups * SG_MAT_N) { + // load O submatrix from shared memory + var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( + &o_shmem, + head_dim_block, + false, + HEAD_DIM_V + ); + + for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { + let p_offset = kv_block * SG_MAT_N; + var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &inter_shmem, + p_offset, + false, + KV_TILE + ); + + // load V submatrix from global or shared memory +#ifdef KV_DIRECT + let v_block_row = kv_tile + kv_block * SG_MAT_N; + let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block; + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &V, + v_global_offset, + false, + params.stride_v1 + ); +#else + let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V; + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &kv_shmem, + v_block_offset + head_dim_block, + false, + HEAD_DIM_V + ); +#endif + // O += P * V + o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); + } + + // store O back to shared memory + subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V); + } + + workgroupBarrier(); + } + +#ifdef SINKS + // add sinks (applied once after processing all KV tiles) + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + if (sg_inv_id == 0) { + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + let val = f32(o_shmem[idx]) * max_exp; + o_shmem[idx] = f16(val); + } + } + + workgroupBarrier(); +#endif + + // write output back to global memory + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0); + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx]; + let scaled = f32(o_val) * scale; + dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled; + } + } +} From 480160d47297df43b43746294963476fc0a6e10f Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Fri, 9 Jan 2026 01:36:42 +0900 Subject: [PATCH 25/27] ggml-webgpu: Fix GGML_MEM_ALIGN to 8 for emscripten. (#18628) * Fix GGML_MEM_ALIGN to 8 for emscripten. * Add a comment explaining the need for GGML_MEM_ALIGN == 8 in 64-bit wasm with emscripten --- ggml/include/ggml.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 20c912d0e9..b69583dd3f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -234,6 +234,11 @@ #if UINTPTR_MAX == 0xFFFFFFFF #define GGML_MEM_ALIGN 4 +#elif defined(__EMSCRIPTEN__) +// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm. +// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.) +// ref: https://github.com/ggml-org/llama.cpp/pull/18628 + #define GGML_MEM_ALIGN 8 #else #define GGML_MEM_ALIGN 16 #endif From 046d5fd44e3505ab9c6d065ab65541fc2fdfd4f2 Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Fri, 9 Jan 2026 05:34:56 +0800 Subject: [PATCH 26/27] llama: use host memory if device reports 0 memory (#18587) --- ggml/src/ggml-backend-impl.h | 2 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 4 ++-- src/llama-model.cpp | 16 ++++++++++++---- src/llama.cpp | 14 +++++++++++++- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 6792ba986e..59190b7c46 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -144,7 +144,7 @@ extern "C" { // device description: short informative description of the device, could be the model name const char * (*get_description)(ggml_backend_dev_t dev); - // device memory in bytes + // device memory in bytes: 0 bytes to indicate no memory to report void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total); // device type diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 472e2df50a..e50ca8e0f2 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4287,8 +4287,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ } static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - *free = 1; - *total = 1; + *free = 0; + *total = 0; GGML_UNUSED(dev); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7ac59846bb..5de6493b9e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2452,6 +2452,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); } + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + // calculate the split points bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); std::vector splits(n_devices()); @@ -2462,6 +2467,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { size_t total; size_t free; ggml_backend_dev_memory(dev, &free, &total); + + // devices can return 0 bytes for free and total memory if they do not + // have any to report. in this case, we will use the host memory as a fallback + // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 + if (free == 0 && total == 0) { + ggml_backend_dev_memory(cpu_dev, &free, &total); + } splits[i] = free; } } else { @@ -2478,10 +2490,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { splits[i] /= split_sum; } - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { diff --git a/src/llama.cpp b/src/llama.cpp index 33f51a2389..f1096d960e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -111,8 +111,20 @@ static std::vector llama_get_device_memory_data( } } for (size_t i = 0; i < ret.size(); i++) { - size_t free, total; + size_t free; + size_t total; ggml_backend_dev_memory(model->devices[i], &free, &total); + + // devices can return 0 bytes for free and total memory if they do not + // have any to report. in this case, we will use the host memory as a fallback + // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 + if (free == 0 && total == 0) { + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + ggml_backend_dev_memory(cpu_dev, &free, &total); + } ret[i].free = free; ret[i].total = total; } From 8ece3836b400dd8d89021ad2cc6e57843ced8378 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 8 Jan 2026 22:35:40 +0100 Subject: [PATCH 27/27] common: support remote preset (#18520) * arg: support remote preset * proof reading * allow one HF repo to point to multiple HF repos * docs: mention about multiple GGUF use case * correct clean_file_name * download: also return HTTP status code * fix case with cache file used * fix --offline option --- common/arg.cpp | 167 ++++++++++++++++++++++++++++++-------------- common/download.cpp | 83 ++++++++++++++-------- common/download.h | 8 +++ common/preset.cpp | 77 +++++++++++++++++++- common/preset.h | 11 ++- docs/preset.md | 60 ++++++++++++++++ 6 files changed, 324 insertions(+), 82 deletions(-) create mode 100644 docs/preset.md diff --git a/common/arg.cpp b/common/arg.cpp index 9c0e6fbe78..72750a3cba 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -6,6 +6,7 @@ #include "log.h" #include "sampling.h" #include "download.h" +#include "preset.h" // fix problem with std::min and std::max #if defined(_WIN32) @@ -268,6 +269,46 @@ static void parse_tensor_buffer_overrides(const std::string & value, std::vector } } +static std::string clean_file_name(const std::string & fname) { + std::string clean_fname = fname; + string_replace_all(clean_fname, "\\", "_"); + string_replace_all(clean_fname, "/", "_"); + return clean_fname; +} + +static bool common_params_handle_remote_preset(common_params & params, llama_example ex) { + GGML_ASSERT(!params.model.hf_repo.empty()); + + const bool offline = params.offline; + std::string model_endpoint = get_model_endpoint(); + auto preset_url = model_endpoint + params.model.hf_repo + "/resolve/main/preset.ini"; + + // prepare local path for caching + auto preset_fname = clean_file_name(params.model.hf_repo + "_preset.ini"); + auto preset_path = fs_get_cache_file(preset_fname); + const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline); + const bool has_preset = status >= 200 && status < 400; + + // remote preset is optional, so we don't error out if not found + if (has_preset) { + LOG_INF("applying remote preset from %s\n", preset_url.c_str()); + common_preset_context ctx(ex, /* only_remote_allowed */ true); + common_preset global; // unused for now + auto remote_presets = ctx.load_from_ini(preset_path, global); + if (remote_presets.find(COMMON_PRESET_DEFAULT_NAME) != remote_presets.end()) { + common_preset & preset = remote_presets.at(COMMON_PRESET_DEFAULT_NAME); + LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline + preset.apply_to_params(params); + } else { + throw std::runtime_error("Remote preset.ini does not contain [" + std::string(COMMON_PRESET_DEFAULT_NAME) + "] section"); + } + } else { + LOG_INF("%s", "no remote preset found, skipping\n"); + } + + return has_preset; +} + struct handle_model_result { bool found_mmproj = false; common_params_model mmproj; @@ -309,9 +350,7 @@ static handle_model_result common_params_handle_model( // make sure model path is present (for caching purposes) if (model.path.empty()) { // this is to avoid different repo having same file name, or same file name in different subdirs - std::string filename = model.hf_repo + "_" + model.hf_file; - // to make sure we don't have any slashes in the filename - string_replace_all(filename, "/", "_"); + std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file); model.path = fs_get_cache_file(filename); } @@ -425,61 +464,87 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context } }; - std::set seen_args; + auto parse_cli_args = [&]() { + std::set seen_args; - for (int i = 1; i < argc; i++) { - const std::string arg_prefix = "--"; + for (int i = 1; i < argc; i++) { + const std::string arg_prefix = "--"; - std::string arg = argv[i]; - if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { - std::replace(arg.begin(), arg.end(), '_', '-'); - } - if (arg_to_options.find(arg) == arg_to_options.end()) { - throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str())); - } - if (!seen_args.insert(arg).second) { - LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str()); - } - auto & tmp = arg_to_options[arg]; - auto opt = *tmp.first; - bool is_positive = tmp.second; - if (opt.has_value_from_env()) { - fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str()); - } - try { - if (opt.handler_void) { - opt.handler_void(params); - continue; + std::string arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); } - if (opt.handler_bool) { - opt.handler_bool(params, is_positive); - continue; + if (arg_to_options.find(arg) == arg_to_options.end()) { + throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str())); } + if (!seen_args.insert(arg).second) { + LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str()); + } + auto & tmp = arg_to_options[arg]; + auto opt = *tmp.first; + bool is_positive = tmp.second; + if (opt.has_value_from_env()) { + fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str()); + } + try { + if (opt.handler_void) { + opt.handler_void(params); + continue; + } + if (opt.handler_bool) { + opt.handler_bool(params, is_positive); + continue; + } - // arg with single value - check_arg(i); - std::string val = argv[++i]; - if (opt.handler_int) { - opt.handler_int(params, std::stoi(val)); - continue; - } - if (opt.handler_string) { - opt.handler_string(params, val); - continue; - } + // arg with single value + check_arg(i); + std::string val = argv[++i]; + if (opt.handler_int) { + opt.handler_int(params, std::stoi(val)); + continue; + } + if (opt.handler_string) { + opt.handler_string(params, val); + continue; + } - // arg with 2 values - check_arg(i); - std::string val2 = argv[++i]; - if (opt.handler_str_str) { - opt.handler_str_str(params, val, val2); - continue; + // arg with 2 values + check_arg(i); + std::string val2 = argv[++i]; + if (opt.handler_str_str) { + opt.handler_str_str(params, val, val2); + continue; + } + } catch (std::exception & e) { + throw std::invalid_argument(string_format( + "error while handling argument \"%s\": %s\n\n" + "usage:\n%s\n\nto show complete usage, run with -h", + arg.c_str(), e.what(), opt.to_string().c_str())); } - } catch (std::exception & e) { - throw std::invalid_argument(string_format( - "error while handling argument \"%s\": %s\n\n" - "usage:\n%s\n\nto show complete usage, run with -h", - arg.c_str(), e.what(), opt.to_string().c_str())); + } + }; + + // parse the first time to get -hf option (used for remote preset) + parse_cli_args(); + + // maybe handle remote preset + if (!params.model.hf_repo.empty()) { + std::string cli_hf_repo = params.model.hf_repo; + bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex); + + // special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value) + // this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs) + std::string preset_hf_repo = params.model.hf_repo; + bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo; + + if (has_preset) { + // re-parse CLI args to override preset values + parse_cli_args(); + } + + // preserve hf_repo from preset if needed + if (preset_has_hf_repo) { + params.model.hf_repo = preset_hf_repo; } } diff --git a/common/download.cpp b/common/download.cpp index 6f56b5518f..a1e0e518e9 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -157,6 +157,10 @@ static std::string read_etag(const std::string & path) { return none; } +static bool is_http_status_ok(int status) { + return status >= 200 && status < 400; +} + #ifdef LLAMA_USE_CURL // @@ -306,12 +310,14 @@ static bool common_download_head(CURL * curl, } // download one single file from remote URL to local path -static bool common_download_file_single_online(const std::string & url, +// returns status code or -1 on error +static int common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token, const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; + for (int i = 0; i < max_attempts; ++i) { std::string etag; @@ -371,7 +377,7 @@ static bool common_download_file_single_online(const std::string & url, LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); if (remove(path.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; + return -1; } } @@ -380,14 +386,14 @@ static bool common_download_file_single_online(const std::string & url, if (std::filesystem::exists(path_temporary)) { if (remove(path_temporary.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); - return false; + return -1; } } if (std::filesystem::exists(path)) { if (remove(path.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; + return -1; } } } @@ -414,23 +420,27 @@ static bool common_download_file_single_online(const std::string & url, long http_code = 0; curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code < 200 || http_code >= 400) { + + int status = static_cast(http_code); + if (!is_http_status_ok(http_code)) { LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); - return false; + return status; // TODO: maybe only return on certain codes } if (rename(path_temporary.c_str(), path.c_str()) != 0) { LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return false; + return -1; } + + return static_cast(http_code); } else { LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); - } - break; + return 304; // Not Modified - fake cached response + } } - return true; + return -1; // max attempts reached } std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params) { @@ -625,7 +635,8 @@ static bool common_pull_file(httplib::Client & cli, } // download one single file from remote URL to local path -static bool common_download_file_single_online(const std::string & url, +// returns status code or -1 on error +static int common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token, const common_header_list & custom_headers) { @@ -659,8 +670,10 @@ static bool common_download_file_single_online(const std::string & url, LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1); if (file_exists) { LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str()); - return true; + return 304; // 304 Not Modified - fake cached response } + return head->status; // cannot use cached file, return raw status code + // TODO: maybe retry only on certain codes } std::string etag; @@ -692,12 +705,12 @@ static bool common_download_file_single_online(const std::string & url, if (file_exists) { if (!should_download_from_scratch) { LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); - return true; + return 304; // 304 Not Modified - fake cached response } LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); if (remove(path.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; + return -1; } } @@ -709,7 +722,7 @@ static bool common_download_file_single_online(const std::string & url, existing_size = std::filesystem::file_size(path_temporary); } else if (remove(path_temporary.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); - return false; + return -1; } } @@ -730,15 +743,16 @@ static bool common_download_file_single_online(const std::string & url, if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return false; + return -1; } if (!etag.empty()) { write_etag(path, etag); } - break; + + return head->status; // TODO: use actual GET status? } - return true; + return -1; // max attempts reached } std::pair> common_remote_get_content(const std::string & url, @@ -777,22 +791,22 @@ std::pair> common_remote_get_content(const std::string #if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB) -static bool common_download_file_single(const std::string & url, - const std::string & path, - const std::string & bearer_token, - bool offline, - const common_header_list & headers) { +int common_download_file_single(const std::string & url, + const std::string & path, + const std::string & bearer_token, + bool offline, + const common_header_list & headers) { if (!offline) { return common_download_file_single_online(url, path, bearer_token, headers); } if (!std::filesystem::exists(path)) { LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str()); - return false; + return -1; } LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); - return true; + return 304; // Not Modified - fake cached response } // download multiple files from remote URLs to local paths @@ -810,7 +824,8 @@ static bool common_download_file_multiple(const std::vector & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token, offline, headers); + const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers); + return is_http_status_ok(http_status); }, item ) @@ -837,7 +852,8 @@ bool common_download_model(const common_params_model & model, return false; } - if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) { + const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers); + if (!is_http_status_ok(http_status)) { return false; } @@ -975,7 +991,7 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, } else if (res_code == 401) { throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); } else { - throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); + throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str())); } // check response @@ -1094,7 +1110,8 @@ std::string common_docker_resolve_model(const std::string & docker) { std::string local_path = fs_get_cache_file(model_filename); const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; - if (!common_download_file_single(blob_url, local_path, token, false, {})) { + const int http_status = common_download_file_single(blob_url, local_path, token, false, {}); + if (!is_http_status_ok(http_status)) { throw std::runtime_error("Failed to download Docker Model"); } @@ -1120,6 +1137,14 @@ std::string common_docker_resolve_model(const std::string &) { throw std::runtime_error("download functionality is not enabled in this build"); } +int common_download_file_single(const std::string &, + const std::string &, + const std::string &, + bool, + const common_header_list &) { + throw std::runtime_error("download functionality is not enabled in this build"); +} + #endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB std::vector common_list_cached_models() { diff --git a/common/download.h b/common/download.h index 9ea2093939..c79be2f90e 100644 --- a/common/download.h +++ b/common/download.h @@ -65,6 +65,14 @@ bool common_download_model( // returns list of cached models std::vector common_list_cached_models(); +// download single file from url to local path +// returns status code or -1 on error +int common_download_file_single(const std::string & url, + const std::string & path, + const std::string & bearer_token, + bool offline, + const common_header_list & headers = {}); + // resolve and download model from Docker registry // return local path to downloaded model file std::string common_docker_resolve_model(const std::string & docker); diff --git a/common/preset.cpp b/common/preset.cpp index e2fc18c5da..aec14e0769 100644 --- a/common/preset.cpp +++ b/common/preset.cpp @@ -16,6 +16,46 @@ static std::string rm_leading_dashes(const std::string & str) { return str.substr(pos); } +// only allow a subset of args for remote presets for security reasons +// do not add more args unless absolutely necessary +// args that output to files are strictly prohibited +static std::set get_remote_preset_whitelist(const std::map & key_to_opt) { + static const std::set allowed_options = { + "model-url", + "hf-repo", + "hf-repo-draft", + "hf-repo-v", // vocoder + "hf-file-v", // vocoder + "mmproj-url", + "pooling", + "jinja", + "batch-size", + "ubatch-size", + "cache-reuse", + // note: sampling params are automatically allowed by default + // negated args will be added automatically + }; + + std::set allowed_keys; + + for (const auto & it : key_to_opt) { + const std::string & key = it.first; + const common_arg & opt = it.second; + if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) { + allowed_keys.insert(key); + // also add variant keys (args without leading dashes and env vars) + for (const auto & arg : opt.get_args()) { + allowed_keys.insert(rm_leading_dashes(arg)); + } + for (const auto & env : opt.get_env()) { + allowed_keys.insert(env); + } + } + } + + return allowed_keys; +} + std::vector common_preset::to_args(const std::string & bin_path) const { std::vector args; @@ -121,6 +161,29 @@ void common_preset::merge(const common_preset & other) { } } +void common_preset::apply_to_params(common_params & params) const { + for (const auto & [opt, val] : options) { + // apply each option to params + if (opt.handler_string) { + opt.handler_string(params, val); + } else if (opt.handler_int) { + opt.handler_int(params, std::stoi(val)); + } else if (opt.handler_bool) { + opt.handler_bool(params, common_arg_utils::is_truthy(val)); + } else if (opt.handler_str_str) { + // not supported yet + throw std::runtime_error(string_format( + "%s: option with two values is not supported yet", + __func__ + )); + } else if (opt.handler_void) { + opt.handler_void(params); + } else { + GGML_ABORT("unknown handler type"); + } + } +} + static std::map> parse_ini_from_file(const std::string & path) { std::map> parsed; @@ -230,10 +293,16 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke return value; } -common_preset_context::common_preset_context(llama_example ex) +common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed) : ctx_params(common_params_parser_init(default_params, ex)) { common_params_add_preset_options(ctx_params.options); key_to_opt = get_map_key_opt(ctx_params); + + // setup allowed keys if only_remote_allowed is true + if (only_remote_allowed) { + filter_allowed_keys = true; + allowed_keys = get_remote_preset_whitelist(key_to_opt); + } } common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const { @@ -250,6 +319,12 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co LOG_DBG("loading preset: %s\n", preset.name.c_str()); for (const auto & [key, value] : section.second) { LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str()); + if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) { + throw std::runtime_error(string_format( + "option '%s' is not allowed in remote presets", + key.c_str() + )); + } if (key_to_opt.find(key) != key_to_opt.end()) { const auto & opt = key_to_opt.at(key); if (is_bool_arg(opt)) { diff --git a/common/preset.h b/common/preset.h index 3a84d1be29..11ba6ef812 100644 --- a/common/preset.h +++ b/common/preset.h @@ -6,6 +6,7 @@ #include #include #include +#include // // INI preset parser and writer @@ -40,6 +41,9 @@ struct common_preset { // merge another preset into this one, overwriting existing options void merge(const common_preset & other); + + // apply preset options to common_params + void apply_to_params(common_params & params) const; }; // interface for multiple presets in one file @@ -50,7 +54,12 @@ struct common_preset_context { common_params default_params; // unused for now common_params_context ctx_params; std::map key_to_opt; - common_preset_context(llama_example ex); + + bool filter_allowed_keys = false; + std::set allowed_keys; + + // if only_remote_allowed is true, only accept whitelisted keys + common_preset_context(llama_example ex, bool only_remote_allowed = false); // load presets from INI file common_presets load_from_ini(const std::string & path, common_preset & global) const; diff --git a/docs/preset.md b/docs/preset.md new file mode 100644 index 0000000000..be50bb9926 --- /dev/null +++ b/docs/preset.md @@ -0,0 +1,60 @@ +# llama.cpp INI Presets + +## Introduction + +The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/llama.cpp/pull/17859), allows users to create reusable and shareable parameter configurations for llama.cpp. + +### Using Presets with the Server + +When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details. + +### Using a Remote Preset + +> [!NOTE] +> +> This feature is currently only supported via the `-hf` option. + +For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model. + +Example: + +```ini +hf-repo-draft = username/my-draft-model-GGUF +temp = 0.5 +top-k = 20 +top-p = 0.95 +``` + +For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options. + +Example usage: + +Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above: + +```sh +llama-cli -hf username/my-model-with-preset + +# This is equivalent to: +llama-cli -hf username/my-model-with-preset \ + --hf-repo-draft username/my-draft-model-GGUF \ + --temp 0.5 \ + --top-k 20 \ + --top-p 0.95 +``` + +You can also override preset arguments by specifying them on the command line: + +```sh +# Force temp = 0.1, overriding the preset value +llama-cli -hf username/my-model-with-preset --temp 0.1 +``` + +If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s): + +```ini +hf-repo = user/my-model-main +hf-repo-draft = user/my-model-draft +temp = 0.8 +ctx-size = 1024 +; (and other configurations) +```