diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c85999b89a..9e0c4a6752 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -234,7 +234,7 @@ jobs: strategy: matrix: sanitizer: [ADDRESS, THREAD, UNDEFINED] - build_type: [Debug, Release] + build_type: [Debug] steps: - name: Clone diff --git a/.gitignore b/.gitignore index 1df7cf4a15..694f36e042 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ *.metallib *.o *.so +*.swp *.tmp # IDE / OS diff --git a/CMakeLists.txt b/CMakeLists.txt index 458b65603a..0bd767a62c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,11 +83,8 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake) # override ggml options -set(GGML_SANITIZE_THREAD ${LLAMA_SANITIZE_THREAD}) -set(GGML_SANITIZE_ADDRESS ${LLAMA_SANITIZE_ADDRESS}) -set(GGML_SANITIZE_UNDEFINED ${LLAMA_SANITIZE_UNDEFINED}) -set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS}) -set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS}) +set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS}) +set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS}) # change the default for these ggml options if (NOT DEFINED GGML_LLAMAFILE) @@ -118,16 +115,62 @@ llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) llama_option_depr(WARNING LLAMA_CANN GGML_CANN) llama_option_depr(WARNING LLAMA_QNN GGML_QNN) +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + message(STATUS "Using -fsanitize=thread") + + add_compile_options(-fsanitize=thread) + link_libraries (-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + message(STATUS "Using -fsanitize=address") + + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries (-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + message(STATUS "Using -fsanitize=undefined") + + add_compile_options(-fsanitize=undefined) + link_libraries (-fsanitize=undefined) + endif() +endif() + # -# build the library +# 3rd-party # if (NOT TARGET ggml) add_subdirectory(ggml) # ... otherwise assume ggml is added by a parent CMakeLists.txt endif() + +# +# build the library +# + add_subdirectory(src) +# +# utils, programs, examples and tests +# + +if (LLAMA_BUILD_COMMON) + add_subdirectory(common) +endif() + +if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) + include(CTest) + add_subdirectory(tests) +endif() + +if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES) + add_subdirectory(examples) + add_subdirectory(pocs) +endif() + # # install # @@ -201,21 +244,3 @@ configure_file(cmake/llama.pc.in install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" DESTINATION lib/pkgconfig) - -# -# utils, programs, examples and tests -# - -if (LLAMA_BUILD_COMMON) - add_subdirectory(common) -endif() - -if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) - include(CTest) - add_subdirectory(tests) -endif() - -if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES) - add_subdirectory(examples) - add_subdirectory(pocs) -endif() diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5a85ec5d2e..8d411982b4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,10 +1,10 @@ # Pull requests (for contributors) - Test your changes: - - Execute [the full CI locally on your machine](ci/README.md) before publishing - - Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`) - - If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends) - - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops` + - Execute [the full CI locally on your machine](ci/README.md) before publishing + - Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`) + - If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends) + - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops` - Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly - If your PR becomes stale, don't hesitate to ping the maintainers in the comments @@ -20,14 +20,104 @@ - Avoid adding third-party dependencies, extra files, extra headers, etc. - Always consider cross-compatibility with other operating systems and architectures - Avoid fancy-looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple -- There are no strict rules for the code style, but try to follow the patterns in the code (indentation, spaces, etc.). Vertical alignment makes things more readable and easier to batch edit +- Vertical alignment makes things more readable and easier to batch edit - Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a` -- Naming usually optimizes for common prefix (see https://github.com/ggerganov/ggml/pull/302#discussion_r1243240963) +- Use sized integer types such as `int32_t` in the public API, e.g. `size_t` may also be appropriate for allocation sizes or byte offsets +- Declare structs with `struct foo {}` instead of `typedef struct foo {} foo` + - In C++ code omit optional `struct` and `enum` keyword whenever they are not necessary + ```cpp + // OK + llama_context * ctx; + const llama_rope_type rope_type; + + // not OK + struct llama_context * ctx; + const enum llama_rope_type rope_type; + ``` + + _(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline.)_ + +- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` to format the added code +- For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) - Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices - Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ ![matmul](media/matmul.png) +# Naming guidelines + +- Use `snake_case` for function, variable and type names +- Naming usually optimizes for longest common prefix (see https://github.com/ggerganov/ggml/pull/302#discussion_r1243240963) + + ```cpp + // not OK + int small_number; + int big_number; + + // OK + int number_small; + int number_big; + ``` + +- Enum values are always in upper case and prefixed with the enum name + + ```cpp + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_NONE = 0, + LLAMA_VOCAB_TYPE_SPM = 1, + LLAMA_VOCAB_TYPE_BPE = 2, + LLAMA_VOCAB_TYPE_WPM = 3, + LLAMA_VOCAB_TYPE_UGM = 4, + LLAMA_VOCAB_TYPE_RWKV = 5, + }; + ``` + +- The general naming pattern is `_`, with `` being `_` + + ```cpp + llama_model_init(); // class: "llama_model", method: "init" + llama_sampler_chain_remove(); // class: "llama_sampler_chain", method: "remove" + llama_sampler_get_seed(); // class: "llama_sampler", method: "get_seed" + llama_set_embeddings(); // class: "llama_context", method: "set_embeddings" + llama_n_threads(); // class: "llama_context", method: "n_threads" + llama_adapter_lora_free(); // class: "llama_adapter_lora", method: "free" + ``` + + - The `get` `` can be omitted + - The `` can be omitted if not necessary + - The `_context` suffix of the `` is optional. Use it to disambiguate symbols when needed + - Use `init`/`free` for constructor/destructor `` + +- Use the `_t` suffix when a type is supposed to be opaque to the user - it's not relevant to them if it is a struct or anything else + + ```cpp + typedef struct llama_context * llama_context_t; + + enum llama_pooling_type llama_pooling_type(const llama_context_t ctx); + ``` + + _(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline)_ + +- C/C++ filenames are all lowercase with dashes. Headers use the `.h` extension. Source files use the `.c` or `.cpp` extension +- Python filenames are all lowercase with underscores + +- _(TODO: abbreviations usage)_ + +# Preprocessor directives + +- _(TODO: add guidelines with examples and apply them to the codebase)_ + + ```cpp + #ifdef FOO + #endif // FOO + ``` + +# Documentation + +- Documentation is a community effort +- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference +- When you notice incorrect or outdated documentation, please update it + # Resources The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects: diff --git a/README.md b/README.md index 6302ac977a..784669ce13 100644 --- a/README.md +++ b/README.md @@ -204,6 +204,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs - [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly - [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server +- [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale @@ -245,6 +246,8 @@ The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](htt - [Trending](https://huggingface.co/models?library=gguf&sort=trending) - [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf) +You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from Hugging Face by using this CLI argument: `-hf /[:quant]` + After downloading a model, use the CLI tools to run it locally - see below. `llama.cpp` requires the model to be stored in the [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) file format. Models in other data formats can be converted to GGUF using the `convert_*.py` Python scripts in this repo. @@ -263,21 +266,12 @@ To learn more about model quantization, [read this documentation](examples/quant #### A CLI tool for accessing and experimenting with most of `llama.cpp`'s functionality. -
- Run simple text completion - - ```bash - llama-cli -m model.gguf -p "I believe the meaning of life is" -n 128 - - # I believe the meaning of life is to find your own truth and to live in accordance with it. For me, this means being true to myself and following my passions, even if they don't align with societal expectations. I think that's what I love about yoga – it's not just a physical practice, but a spiritual one too. It's about connecting with yourself, listening to your inner voice, and honoring your own unique journey. - ``` - -
- --
Run in conversation mode + Models with a built-in chat template will automatically activate conversation mode. If this doesn't occur, you can manually enable it by adding `-cnv` and specifying a suitable chat template with `--chat-template NAME` + ```bash - llama-cli -m model.gguf -p "You are a helpful assistant" -cnv + llama-cli -m model.gguf # > hi, who are you? # Hi there! I'm your helpful assistant! I'm an AI-powered chatbot designed to assist and provide information to users like you. I'm here to help answer your questions, provide guidance, and offer support on a wide range of topics. I'm a friendly and knowledgeable AI, and I'm always happy to help with anything you need. What's on your mind, and how can I assist you today? @@ -289,17 +283,28 @@ To learn more about model quantization, [read this documentation](examples/quant
-
- Run with custom chat template + Run in conversation mode with custom chat template ```bash - # use the "chatml" template - llama-cli -m model.gguf -p "You are a helpful assistant" -cnv --chat-template chatml + # use the "chatml" template (use -h to see the list of supported templates) + llama-cli -m model.gguf -cnv --chat-template chatml # use a custom template - llama-cli -m model.gguf -p "You are a helpful assistant" -cnv --in-prefix 'User: ' --reverse-prompt 'User:' + llama-cli -m model.gguf -cnv --in-prefix 'User: ' --reverse-prompt 'User:' ``` - [Supported templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) +
+ +-
+ Run simple text completion + + To disable conversation mode explicitly, use `-no-cnv` + + ```bash + llama-cli -m model.gguf -p "I believe the meaning of life is" -n 128 -no-cnv + + # I believe the meaning of life is to find your own truth and to live in accordance with it. For me, this means being true to myself and following my passions, even if they don't align with societal expectations. I think that's what I love about yoga – it's not just a physical practice, but a spiritual one too. It's about connecting with yourself, listening to your inner voice, and honoring your own unique journey. + ```
diff --git a/ci/run.sh b/ci/run.sh index abf08a4ff6..77c32ce005 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -326,17 +326,17 @@ function gg_run_open_llama_7b_v2 { ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k - (time ./bin/llama-cli --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-cli --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-cli --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-cli --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-cli --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-cli --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-cli --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-cli --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-cli --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-cli --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-cli --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log @@ -460,17 +460,17 @@ function gg_run_pythia_1_4b { ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k - (time ./bin/llama-cli --model ${model_f16} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-cli --model ${model_q8_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-cli --model ${model_q4_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-cli --model ${model_q4_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-cli --model ${model_q5_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-cli --model ${model_q5_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-cli --model ${model_q2_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-cli --model ${model_q3_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-cli --model ${model_q4_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-cli --model ${model_q5_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-cli --model ${model_q6_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-cli -no-cnv --model ${model_f16} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log @@ -591,17 +591,17 @@ function gg_run_pythia_2_8b { ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k - (time ./bin/llama-cli --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-cli --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-cli --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-cli --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-cli --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-cli --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-cli --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-cli --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-cli --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-cli --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-cli --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log diff --git a/common/arg.cpp b/common/arg.cpp index 27886b84e8..dede335fbc 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -130,17 +130,26 @@ std::string common_arg::to_string() { static void common_params_handle_model_default( std::string & model, - std::string & model_url, + const std::string & model_url, std::string & hf_repo, - std::string & hf_file) { + std::string & hf_file, + const std::string & hf_token) { if (!hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model if (hf_file.empty()) { if (model.empty()) { - throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); + auto auto_detected = common_get_hf_file(hf_repo, hf_token); + if (auto_detected.first.empty() || auto_detected.second.empty()) { + exit(1); // built without CURL, error message already printed + } + hf_repo = auto_detected.first; + hf_file = auto_detected.second; + } else { + hf_file = model; } - hf_file = model; - } else if (model.empty()) { + } + // make sure model path is present (for caching purposes) + if (model.empty()) { // this is to avoid different repo having same file name, or same file name in different subdirs std::string filename = hf_repo + "_" + hf_file; // to make sure we don't have any slashes in the filename @@ -290,8 +299,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context } // TODO: refactor model params in a common struct - common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file); - common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file); + common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token); + common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token); if (params.escape) { string_process_escapes(params.prompt); @@ -367,6 +376,30 @@ static std::vector parse_device_list(const std::string & val return devices; } +static void add_rpc_devices(std::string servers) { + auto rpc_servers = string_split(servers, ','); + if (rpc_servers.empty()) { + throw std::invalid_argument("no RPC servers specified"); + } + ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); + if (!rpc_reg) { + throw std::invalid_argument("failed to find RPC backend"); + } + typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); + ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); + if (!ggml_backend_rpc_add_device_fn) { + throw std::invalid_argument("failed to find RPC device add function"); + } + for (const auto & server : rpc_servers) { + ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); + if (dev) { + ggml_backend_device_register(dev); + } else { + throw std::invalid_argument("failed to register RPC device"); + } + } +} + bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { auto ctx_arg = common_params_parser_init(params, ex, print_usage); const common_params params_org = ctx_arg.params; // the example can modify the default params @@ -768,15 +801,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-cnv", "--conversation"}, - string_format( - "run in conversation mode:\n" - "- does not print special tokens and suffix/prefix\n" - "- interactive mode is also enabled\n" - "(default: %s)", - params.conversation ? "true" : "false" - ), + "run in conversation mode:\n" + "- does not print special tokens and suffix/prefix\n" + "- interactive mode is also enabled\n" + "(default: auto enabled if chat template is available)", [](common_params & params) { - params.conversation = true; + params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-no-cnv", "--no-conversation"}, + "force disable conversation mode (default: false)", + [](common_params & params) { + params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED; } ).set_examples({LLAMA_EXAMPLE_MAIN})); add_opt(common_arg( @@ -1372,7 +1409,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--rpc"}, "SERVERS", "comma separated list of RPC servers", [](common_params & params, const std::string & value) { - params.rpc_servers = value; + add_rpc_devices(value); + GGML_UNUSED(params); } ).set_env("LLAMA_ARG_RPC")); } @@ -1583,21 +1621,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_MODEL_URL")); add_opt(common_arg( - {"-hfr", "--hf-repo"}, "REPO", - "Hugging Face model repository (default: unused)", + {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", + "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" + "example: unsloth/phi-4-GGUF:q4_k_m\n" + "(default: unused)", [](common_params & params, const std::string & value) { params.hf_repo = value; } ).set_env("LLAMA_ARG_HF_REPO")); add_opt(common_arg( {"-hff", "--hf-file"}, "FILE", - "Hugging Face model file (default: unused)", + "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)", [](common_params & params, const std::string & value) { params.hf_file = value; } ).set_env("LLAMA_ARG_HF_FILE")); add_opt(common_arg( - {"-hfrv", "--hf-repo-v"}, "REPO", + {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", "Hugging Face model repository for the vocoder model (default: unused)", [](common_params & params, const std::string & value) { params.vocoder.hf_repo = value; @@ -2214,6 +2254,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.vocoder.model = value; } ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--tts-use-guide-tokens"}, + "Use guide tokens to improve TTS word recall", + [](common_params & params) { + params.vocoder.use_guide_tokens = true; + } + ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); // model-specific add_opt(common_arg( diff --git a/common/common.cpp b/common/common.cpp index 86e4e1e24e..451826d5d6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -73,6 +73,22 @@ #include #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + +// +// CURL utils +// + +using curl_ptr = std::unique_ptr; + +// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one +struct curl_slist_ptr { + struct curl_slist * ptr = nullptr; + ~curl_slist_ptr() { + if (ptr) { + curl_slist_free_all(ptr); + } + } +}; #endif // LLAMA_USE_CURL using json = nlohmann::ordered_json; @@ -857,21 +873,23 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + const llama_vocab * vocab = llama_model_get_vocab(model); + if (params.reranking) { bool ok = true; - if (llama_token_bos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__); + if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); ok = false; } - if (llama_token_eos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__); + if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__); ok = false; } - if (llama_token_sep(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__); + if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); ok = false; } @@ -884,7 +902,7 @@ struct common_init_result common_init_from_params(common_params & params) { auto cparams = common_context_params_to_llama(params); - llama_context * lctx = llama_new_context_with_model(model, cparams); + llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); llama_model_free(model); @@ -898,7 +916,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (!params.control_vectors.empty()) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; - if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model); const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { @@ -908,12 +926,13 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } - int err = llama_control_vector_apply(lctx, - cvec.data.data(), - cvec.data.size(), - cvec.n_embd, - params.control_vector_layer_start, - params.control_vector_layer_end); + int err = llama_apply_adapter_cvec( + lctx, + cvec.data.data(), + cvec.data.size(), + cvec.n_embd, + params.control_vector_layer_start, + params.control_vector_layer_end); if (err) { llama_free(lctx); llama_model_free(model); @@ -924,8 +943,8 @@ struct common_init_result common_init_from_params(common_params & params) { // load and optionally apply lora adapters for (auto & la : params.lora_adapters) { - llama_lora_adapter_ptr lora; - lora.reset(llama_lora_adapter_init(model, la.path.c_str())); + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model, la.path.c_str())); if (lora == nullptr) { LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); llama_free(lctx); @@ -938,17 +957,17 @@ struct common_init_result common_init_from_params(common_params & params) { } if (!params.lora_init_without_apply) { - common_lora_adapters_apply(lctx, params.lora_adapters); + common_set_adapter_lora(lctx, params.lora_adapters); } - if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); params.sampling.ignore_eos = false; } if (params.sampling.ignore_eos) { - for (llama_token i = 0; i < llama_n_vocab(model); i++) { - if (llama_token_is_eog(model, i)) { + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); params.sampling.logit_bias.push_back({i, -INFINITY}); } @@ -969,8 +988,9 @@ struct common_init_result common_init_from_params(common_params & params) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); std::vector tmp; - llama_token bos = llama_token_bos(model); - llama_token eos = llama_token_eos(model); + llama_token bos = llama_vocab_bos(vocab); + llama_token eos = llama_vocab_eos(vocab); + // some models (e.g. T5) don't have a BOS token if (bos != LLAMA_TOKEN_NULL) { tmp.push_back(bos); @@ -1005,11 +1025,11 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } -void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora) { - llama_lora_adapter_clear(ctx); +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { + llama_clear_adapter_lora(ctx); for (auto & la : lora) { if (la.scale != 0.0f) { - llama_lora_adapter_set(ctx, la.ptr, la.scale); + llama_set_adapter_lora(ctx, la.ptr, la.scale); } } } @@ -1023,7 +1043,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } - mparams.rpc_servers = params.rpc_servers.c_str(); mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; @@ -1126,7 +1145,8 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl - std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; if (!curl) { LOG_ERR("%s: error initializing libcurl\n", __func__); return false; @@ -1140,11 +1160,9 @@ static bool common_download_file(const std::string & url, const std::string & pa // Check if hf-token or bearer-token was specified if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer "; - auth_header += hf_token.c_str(); - struct curl_slist *http_headers = NULL; - http_headers = curl_slist_append(http_headers, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers); + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); } #if defined(_WIN32) @@ -1440,6 +1458,80 @@ struct llama_model * common_load_model_from_hf( return common_load_model_from_url(model_url, local_path, hf_token, params); } +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) + * + * Return pair of (with "repo" already having tag removed) + * + * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + */ +std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { + auto parts = string_split(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? parts.back() : "latest"; + std::string hf_repo = parts[0]; + if (string_split(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); + } + + // fetch model info from Hugging Face Hub API + json model_info; + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + std::string res_str; + std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + static_cast(data)->append((char * ) ptr, size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (!hf_token.empty()) { + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + } + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + throw std::runtime_error("error: cannot make GET request to HF API"); + } + + long res_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + if (res_code == 200) { + model_info = json::parse(res_str); + } else if (res_code == 401) { + throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + } else { + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); + } + + // check response + if (!model_info.contains("ggufFile")) { + throw std::runtime_error("error: model does not have ggufFile"); + } + json & gguf_file = model_info.at("ggufFile"); + if (!gguf_file.contains("rfilename")) { + throw std::runtime_error("error: ggufFile does not have rfilename"); + } + + return std::make_pair(hf_repo, gguf_file.at("rfilename")); +} + #else struct llama_model * common_load_model_from_url( @@ -1461,6 +1553,11 @@ struct llama_model * common_load_model_from_hf( return nullptr; } +std::pair common_get_hf_file(const std::string &, const std::string &) { + LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); + return std::make_pair("", ""); +} + #endif // LLAMA_USE_CURL // @@ -1559,21 +1656,23 @@ std::vector common_tokenize( const std::string & text, bool add_special, bool parse_special) { - return common_tokenize(llama_get_model(ctx), text, add_special, parse_special); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_tokenize(vocab, text, add_special, parse_special); } std::vector common_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const std::string & text, bool add_special, bool parse_special) { // upper limit for the number of tokens int n_tokens = text.length() + 2 * add_special; std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -1582,12 +1681,18 @@ std::vector common_tokenize( } std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_token_to_piece(vocab, token, special); +} + +std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) { std::string piece; piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' - const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); if (n_chars < 0) { piece.resize(-n_chars); - int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); GGML_ASSERT(check == -n_chars); } else { @@ -1597,13 +1702,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token return piece; } -std::string common_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { +std::string common_detokenize(const struct llama_context * ctx, const std::vector & tokens, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_detokenize(vocab, tokens, special); +} + +std::string common_detokenize(const struct llama_vocab * vocab, const std::vector & tokens, bool special) { std::string text; text.resize(std::max(text.capacity(), tokens.size())); - int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); if (n_chars < 0) { text.resize(-n_chars); - n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization } @@ -1618,20 +1729,13 @@ std::string common_detokenize(llama_context * ctx, const std::vector 0) { - std::vector model_template(res + 1, 0); - llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size()); - return std::string(model_template.data(), model_template.size() - 1); - } - return ""; + const char * ptr_tmpl = llama_model_chat_template(model); + return ptr_tmpl == nullptr ? "" : ptr_tmpl; } bool common_chat_verify_template(const std::string & tmpl) { llama_chat_message chat[] = {{"user", "test"}}; - int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); + const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } @@ -1642,16 +1746,16 @@ std::string common_chat_apply_template(const struct llama_model * model, int alloc_size = 0; bool fallback = false; // indicate if we must fallback to default chatml std::vector chat; - for (auto & msg : msgs) { + for (const auto & msg : msgs) { chat.push_back({msg.role.c_str(), msg.content.c_str()}); alloc_size += (msg.role.size() + msg.content.size()) * 1.25; } - const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); + const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str(); std::vector buf(alloc_size); // run the first time to get the total output length - int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); // error: chat template is not supported if (res < 0) { @@ -1659,18 +1763,17 @@ std::string common_chat_apply_template(const struct llama_model * model, // if the custom "tmpl" is not supported, we throw an error // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() throw std::runtime_error("this custom template is not supported"); - } else { - // If the built-in template is not supported, we default to chatml - res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - fallback = true; } + + // If the built-in template is not supported, we default to chatml + res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + fallback = true; } // if it turns out that our buffer is too small, we resize it if ((size_t) res > buf.size()) { buf.resize(res); res = llama_chat_apply_template( - fallback ? nullptr : model, fallback ? "chatml" : ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); } diff --git a/common/common.h b/common/common.h index 0d452cf0f5..3bcc637cc8 100644 --- a/common/common.h +++ b/common/common.h @@ -24,11 +24,11 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" -struct common_lora_adapter_info { +struct common_adapter_lora_info { std::string path; float scale; - struct llama_lora_adapter * ptr; + struct llama_adapter_lora * ptr; }; using llama_tokens = std::vector; @@ -103,6 +103,12 @@ enum dimre_method { DIMRE_METHOD_MEAN, }; +enum common_conversation_mode { + COMMON_CONVERSATION_MODE_DISABLED = 0, + COMMON_CONVERSATION_MODE_ENABLED = 1, + COMMON_CONVERSATION_MODE_AUTO = 2, +}; + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -178,6 +184,8 @@ struct common_params_vocoder { std::string model = ""; // model path // NOLINT std::string model_url = ""; // model url to download // NOLINT + + bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; struct common_params { @@ -240,14 +248,13 @@ struct common_params { std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT - std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; - bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply) - std::vector lora_adapters; // lora adapter path with user defined scale + bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) + std::vector lora_adapters; // lora adapter path with user defined scale std::vector control_vectors; // control vector with user defined scale @@ -275,7 +282,6 @@ struct common_params { bool special = false; // enable special token output bool interactive = false; // interactive mode bool interactive_first = false; // wait for user input immediately - bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it @@ -301,6 +307,8 @@ struct common_params { ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V + common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; + // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector // NOLINT std::vector image; // path to image file(s) @@ -454,6 +462,11 @@ static bool string_starts_with(const std::string & str, return str.rfind(prefix, 0) == 0; } +static bool string_ends_with(const std::string & str, + const std::string & suffix) { // While we wait for C++20's std::string::ends_with... + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); @@ -481,7 +494,7 @@ struct common_init_result { llama_model_ptr model; llama_context_ptr context; - std::vector lora; + std::vector lora; }; struct common_init_result common_init_from_params(common_params & params); @@ -501,9 +514,12 @@ struct llama_model * common_load_model_from_hf( const std::string & local_path, const std::string & hf_token, const struct llama_model_params & params); +std::pair common_get_hf_file( + const std::string & hf_repo_with_tag, + const std::string & hf_token); // clear LoRA adapters from context, then apply new list of adapters -void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora); +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); // // Batch utils @@ -541,7 +557,7 @@ std::vector common_tokenize( bool parse_special = false); std::vector common_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const std::string & text, bool add_special, bool parse_special = false); @@ -553,11 +569,21 @@ std::string common_token_to_piece( llama_token token, bool special = true); +std::string common_token_to_piece( + const struct llama_vocab * vocab, + llama_token token, + bool special = true); + // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` // optionally renders special/control tokens std::string common_detokenize( - llama_context * ctx, + const struct llama_context * ctx, + const std::vector & tokens, + bool special = true); + +std::string common_detokenize( + const struct llama_vocab * vocab, const std::vector & tokens, bool special = true); diff --git a/common/sampling.cpp b/common/sampling.cpp index e83a971c74..7241ac3212 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -113,7 +113,10 @@ struct common_sampler { void set_logits(struct llama_context * ctx, int idx) { const auto * logits = llama_get_logits_ith(ctx, idx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); cur.resize(n_vocab); @@ -142,13 +145,15 @@ std::string common_params_sampling::print() const { } struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); lparams.no_perf = params.no_perf; auto * result = new common_sampler { /* .params = */ params, - /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), + /* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"), /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, @@ -157,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_logit_bias( - llama_n_vocab(model), + llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); @@ -172,7 +177,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co c_breakers.push_back(str.c_str()); } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: @@ -194,7 +199,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); break; case COMMON_SAMPLER_TYPE_PENALTIES: llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); @@ -206,7 +211,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); diff --git a/common/speculative.cpp b/common/speculative.cpp index 3fcbb0020b..318e96ea35 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -79,10 +79,13 @@ bool common_speculative_are_compatible( const struct llama_model * model_tgt = llama_get_model(ctx_tgt); const struct llama_model * model_dft = llama_get_model(ctx_dft); - const bool vocab_type_tgt = llama_vocab_type(model_tgt); + const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + + const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); - const bool vocab_type_dft = llama_vocab_type(model_dft); + const bool vocab_type_dft = llama_vocab_type(vocab_dft); LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); if (vocab_type_tgt != vocab_type_dft) { @@ -91,34 +94,34 @@ bool common_speculative_are_compatible( return false; } - if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || - llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || - llama_token_bos(model_tgt) != llama_token_bos(model_dft) || - llama_token_eos(model_tgt) != llama_token_eos(model_dft)) { - LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); - LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt)); - LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft)); + if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || + llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || + llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || + llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { + LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); + LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt)); + LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); return false; } { - const int n_vocab_tgt = llama_n_vocab(model_tgt); - const int n_vocab_dft = llama_n_vocab(model_dft); + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", - __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return false; } for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_token_get_text(model_tgt, i); - const char * token_text_dft = llama_token_get_text(model_dft, i); + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - LOG_ERR("%s: draft model vocab must match target model to use speculation but " + LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " "token %d content differs - target '%s', draft '%s'\n", __func__, i, common_token_to_piece(ctx_tgt, i).c_str(), common_token_to_piece(ctx_dft, i).c_str()); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cf317eeae6..95f1120433 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -478,6 +478,11 @@ class Model: return modelcls return func + @classmethod + def print_registered_models(cls): + for name in sorted(cls._model_classes.keys()): + logger.error(f"- {name}") + @classmethod def from_model_architecture(cls, arch: str) -> type[Model]: try: @@ -2877,6 +2882,66 @@ class InternLM2Model(Model): return [(self.map_tensor_name(name), data_torch)] +@Model.register("InternLM3ForCausalLM") +class InternLM3Model(Model): + model_arch = gguf.MODEL_ARCH.LLAMA + + def set_vocab(self): + tokens, scores, toktypes = self._create_vocab_sentencepiece() + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + if "added_tokens_decoder" in tokenizer_config_json: + for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items(): + if token_data.get("special"): + token_id = int(token_id) + token = token_data["content"] + special_vocab._set_special_token(token, token_id) + # update eos token + if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids: + special_vocab.special_token_ids["eos"] = token_id + + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + @Model.register("BertModel", "BertForMaskedLM", "CamembertModel") class BertModel(Model): model_arch = gguf.MODEL_ARCH.BERT @@ -4929,6 +4994,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "model", type=Path, help="directory containing model file", + nargs="?", ) parser.add_argument( "--use-temp-file", action="store_true", @@ -4966,8 +5032,15 @@ def parse_args() -> argparse.Namespace: "--metadata", type=Path, help="Specify the path for an authorship metadata override file" ) + parser.add_argument( + "--print-supported-models", action="store_true", + help="Print the supported models" + ) - return parser.parse_args() + args = parser.parse_args() + if not args.print_supported_models and args.model is None: + parser.error("the following arguments are required: model") + return args def split_str_to_n_bytes(split_str: str) -> int: @@ -4991,6 +5064,11 @@ def split_str_to_n_bytes(split_str: str) -> int: def main() -> None: args = parse_args() + if args.print_supported_models: + logger.error("Supported models:") + Model.print_registered_models() + sys.exit(0) + if args.verbose: logging.basicConfig(level=logging.DEBUG) else: diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index dd75ff9f16..0659ab6f11 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -50,7 +50,7 @@ int main(int argc, char ** argv) { // ensure enough sequences are available ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end()); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 10f2e7fd11..371917b2ee 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -23,12 +23,12 @@ defer { } let model_params = llama_model_default_params() -guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else { +guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else { print("Failed to load model") exit(1) } defer { - llama_free_model(model) + llama_model_free(model) } var tokens = tokenize(text: prompt, add_bos: true) @@ -141,7 +141,7 @@ while n_cur <= n_len { let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) // is it an end of stream? -> mark the stream as finished - if llama_token_is_eog(model, new_token_id) || n_cur == n_len { + if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len { i_batch[i] = -1 // print("") if n_parallel > 1 { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index d34b030996..21b95ef5e4 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -48,10 +48,12 @@ int main(int argc, char ** argv) { return 1; } + const llama_vocab * vocab = llama_model_get_vocab(model); + // tokenize the prompt std::vector tokens_list; - tokens_list = common_tokenize(model, params.prompt, true); + tokens_list = common_tokenize(vocab, params.prompt, true); const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel; @@ -62,7 +64,7 @@ int main(int argc, char ** argv) { ctx_params.n_ctx = n_kv_req; ctx_params.n_batch = std::max(n_predict, n_parallel); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); auto sparams = llama_sampler_chain_default_params(); sparams.no_perf = false; @@ -121,7 +123,7 @@ int main(int argc, char ** argv) { llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == LLAMA_TOKEN_NULL) { - decoder_start_token_id = llama_token_bos(model); + decoder_start_token_id = llama_vocab_bos(vocab); } common_batch_clear(batch); @@ -174,7 +176,7 @@ int main(int argc, char ** argv) { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); // is it an end of generation? -> mark the stream as finished - if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { i_batch[i] = -1; LOG("\n"); if (n_parallel > 1) { diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 1256abb172..bdf0eed2a9 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -911,7 +911,7 @@ int main(int argc, char ** argv) { load_vocab(params.fn_vocab_model, &config, &vocab); struct my_llama_model model; - model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx); + model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx); model.hparams.n_ctx = params.n_ctx; model.hparams.n_embd = config.dim; //params.n_embd; model.hparams.n_ff = config.hidden_dim; diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index e899c10781..413b71d34c 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -273,7 +273,9 @@ struct tokenized_prompt { size_t max_seq_len; tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); tokens_pos = common_tokenize(ctx, pos, add_bos, true); tokens_neg = common_tokenize(ctx, neg, add_bos, true); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); @@ -421,8 +423,8 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_init.context.get(); // int n_ctx = llama_n_ctx(ctx); - int n_layers = llama_n_layer(model); - int n_embd = llama_n_embd(model); + int n_layers = llama_model_n_layer(model); + int n_embd = llama_model_n_embd(model); // get model hint param (a.k.a model arch name) char model_hint[128]; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 27f75cb770..38d22c90f8 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -105,7 +105,9 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); @@ -148,7 +150,7 @@ int main(int argc, char ** argv) { // check if the last token is SEP // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true' for (auto & inp : inputs) { - if (inp.empty() || inp.back() != llama_token_sep(model)) { + if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) { LOG_WRN("%s: last token in the prompt is not SEP\n", __func__); LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__); } @@ -181,7 +183,7 @@ int main(int argc, char ** argv) { } // allocate output - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embeddings(n_embd_count * n_embd, 0); float * emb = embeddings.data(); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 2111c3cda0..fb188f5a9e 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -127,7 +127,10 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { } static bool run(llama_context * ctx, const common_params & params) { - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index d5dcd20a0a..99063b5d53 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include static bool g_verbose = false; @@ -130,7 +129,7 @@ struct lora_merge_ctx { lora_merge_ctx( std::string & base_fname, - std::vector & lora_files, + std::vector & lora_files, std::string & outfile, int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) { fout.exceptions(std::ofstream::failbit); // fail fast on write errors diff --git a/examples/gguf-split/tests.sh b/examples/gguf-split/tests.sh index d5a92d6051..05a9322271 100755 --- a/examples/gguf-split/tests.sh +++ b/examples/gguf-split/tests.sh @@ -41,7 +41,7 @@ echo PASS echo # 2b. Test the sharded model is loading properly -$MAIN --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32 echo PASS echo @@ -51,7 +51,7 @@ echo PASS echo # 3b. Test the merged model is loading properly -$MAIN --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32 echo PASS echo @@ -61,7 +61,7 @@ echo PASS echo # 4b. Test the sharded model is loading properly -$MAIN --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32 echo PASS echo @@ -71,7 +71,7 @@ echo #echo # 5b. Test the merged model is loading properly -#$MAIN --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32 +#$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32 #echo PASS #echo @@ -81,7 +81,7 @@ echo PASS echo # 6b. Test the sharded model is loading properly -$MAIN --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32 echo PASS echo diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 4d2db56249..72eb462574 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -11,6 +11,7 @@ static std::vector> encode(llama_context * ctx, const std::ve std::vector> result; const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); @@ -19,16 +20,16 @@ static std::vector> encode(llama_context * ctx, const std::ve const std::string input_string = instruction + sentences[i]; - std::vector inputs = common_tokenize(model, input_string, true, false); + std::vector inputs = common_tokenize(vocab, input_string, true, false); const int32_t n_toks = inputs.size(); // GritLM seems to have EOS = "" // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 - // inputs.push_back(llama_token_eos(model)); + // inputs.push_back(llama_vocab_eos(vocab)); // we want to ignore instruction tokens for mean pooling - const int32_t n_inst = common_tokenize(model, instruction, true, false).size(); + const int32_t n_inst = common_tokenize(vocab, instruction, true, false).size(); #ifdef GRIT_DEBUG // debug tokens - should be matching as referenced in the GritLM sample @@ -52,7 +53,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_decode(ctx, batch); // get embedding dimensions - uint64_t n_embd = llama_n_embd(model); + uint64_t n_embd = llama_model_n_embd(model); // allocate embedding output std::vector emb_unorm(n_embd, 0.0f); @@ -97,7 +98,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std std::string result; const llama_model * model = llama_get_model(ctx); - llama_token eos_token = llama_token_eos(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + + llama_token eos_token = llama_vocab_eos(vocab); llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); @@ -105,7 +108,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); - std::vector inputs = common_tokenize(model, prompt, false, true); + std::vector inputs = common_tokenize(vocab, prompt, false, true); int32_t i_current_token = 0; while (true) { @@ -168,7 +171,7 @@ int main(int argc, char * argv[]) { llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams); // create generation context - llama_context * ctx = llama_new_context_with_model(model, cparams); + llama_context * ctx = llama_init_from_model(model, cparams); auto sparams = llama_sampler_chain_default_params(); @@ -197,7 +200,7 @@ int main(int argc, char * argv[]) { const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 588114ecd9..b5f3feb9f8 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -40,7 +39,7 @@ public: void set_params(common_params params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix(int ncall = -1) const; - bool load_imatrix(const char * file_name); + bool load_imatrix(const char * fname); private: std::unordered_map m_stats; common_params m_params; @@ -429,10 +428,13 @@ static void process_logits( } static bool compute_imatrix(llama_context * ctx, const common_params & params) { - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); const int n_ctx = llama_n_ctx(ctx); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); auto tim1 = std::chrono::high_resolution_clock::now(); LOG_INF("%s: tokenizing the input ..\n", __func__); @@ -468,7 +470,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { const int n_chunk_max = tokens.size() / n_ctx; const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_vocab = llama_vocab_n_tokens(vocab); const int n_batch = params.n_batch; int count = 0; @@ -508,7 +510,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); + tokens[batch_start] = llama_vocab_bos(vocab); } common_batch_clear(batch); @@ -627,7 +629,7 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, params.n_ctx); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index d460be314c..489a208b66 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -139,7 +139,9 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); LOG_DBG("n_ctx: %d\n", n_ctx); @@ -152,28 +154,28 @@ int main(int argc, char ** argv) { LOG_INF("\n"); LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } - const bool add_bos = llama_add_bos_token(model); - GGML_ASSERT(!llama_add_eos_token(model)); + const bool add_bos = llama_vocab_get_add_bos(vocab); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); std::vector embd_inp; std::vector embd_end; std::vector inp_pfx = common_tokenize(ctx, params.input_prefix, false); std::vector inp_sfx = common_tokenize(ctx, params.input_suffix, false); - GGML_ASSERT(llama_token_fim_pre(model) >= 0); - GGML_ASSERT(llama_token_fim_suf(model) >= 0); + GGML_ASSERT(llama_vocab_fim_pre(vocab) >= 0); + GGML_ASSERT(llama_vocab_fim_suf(vocab) >= 0); - inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model)); - inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model)); + inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab)); + inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab)); embd_inp = params.spm_infill ? inp_sfx : inp_pfx; embd_end = params.spm_infill ? inp_pfx : inp_sfx; if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - const llama_token middle_token = llama_token_fim_mid(model); + const llama_token middle_token = llama_vocab_fim_mid(vocab); if (middle_token >= 0) { embd_inp.push_back(middle_token); } @@ -185,7 +187,7 @@ int main(int argc, char ** argv) { // Should not run without any tokens if (embd_inp.empty()) { - embd_inp.push_back(llama_token_bos(model)); + embd_inp.push_back(llama_vocab_bos(vocab)); LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); } @@ -420,10 +422,10 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((common_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ + if ((common_sampler_last(smpl) == llama_vocab_eot(vocab) || is_interacting) && params.interactive){ if (is_interacting && !params.interactive_first) { // print an eot token - LOG("%s", common_token_to_piece(ctx, llama_token_eot(model)).c_str()); + LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str()); } LOG("\n"); console::set_display(console::user_input); @@ -463,13 +465,13 @@ int main(int argc, char ** argv) { std::vector inp_pfx = common_tokenize(ctx, params.input_prefix, false); std::vector inp_sfx = common_tokenize(ctx, params.input_suffix, false); - inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model)); - inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model)); + inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab)); + inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab)); embd_inp = params.spm_infill ? inp_sfx : inp_pfx; embd_end = params.spm_infill ? inp_pfx : inp_sfx; if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); @@ -484,7 +486,7 @@ int main(int argc, char ** argv) { is_interacting = false; } // deal with end of generation tokens in interactive mode - else if (llama_token_is_eog(model, common_sampler_last(smpl))) { + else if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { LOG_DBG("found EOS token\n"); if (params.interactive) { @@ -500,7 +502,7 @@ int main(int argc, char ** argv) { if (params.input_prefix_bos) { LOG_DBG("adding input prefix BOS token\n"); - embd_inp.push_back(llama_token_bos(model)); + embd_inp.push_back(llama_vocab_bos(vocab)); } std::string buffer; @@ -563,7 +565,7 @@ int main(int argc, char ** argv) { } // end of generation - if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) { + if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !params.interactive) { break; } @@ -575,7 +577,7 @@ int main(int argc, char ** argv) { } } if (!params.interactive && n_remain <= 0) { - LOG("%s", common_token_to_piece(ctx, llama_token_eot(model)).c_str()); + LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str()); } LOG("\n"); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 2a09167660..4ac19ca86e 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -683,7 +683,7 @@ struct cmd_params_instance { bool cpu_strict; int poll; int n_gpu_layers; - std::string rpc_servers; + std::string rpc_servers_str; llama_split_mode split_mode; int main_gpu; bool no_kv_offload; @@ -696,8 +696,37 @@ struct cmd_params_instance { llama_model_params mparams = llama_model_default_params(); mparams.n_gpu_layers = n_gpu_layers; - if (!rpc_servers.empty()) { - mparams.rpc_servers = rpc_servers.c_str(); + if (!rpc_servers_str.empty()) { + auto rpc_servers = string_split(rpc_servers_str, ','); + + // add RPC devices + if (!rpc_servers.empty()) { + ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); + if (!rpc_reg) { + fprintf(stderr, "%s: failed to find RPC backend\n", __func__); + exit(1); + } + + typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); + ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); + if (!ggml_backend_rpc_add_device_fn) { + fprintf(stderr, "%s: failed to find RPC device add function\n", __func__); + exit(1); + } + static std::vector devices; + devices.clear(); + for (const std::string & server : rpc_servers) { + ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); + if (dev) { + devices.push_back(dev); + } else { + fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str()); + exit(1); + } + } + devices.push_back(nullptr); + mparams.devices = devices.data(); + } } mparams.split_mode = split_mode; mparams.main_gpu = main_gpu; @@ -708,7 +737,7 @@ struct cmd_params_instance { } bool equal_mparams(const cmd_params_instance & other) const { - return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers == other.rpc_servers && + return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str && split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split; } @@ -1401,7 +1430,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); - const int32_t n_vocab = llama_n_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); std::vector tokens(n_batch); @@ -1409,7 +1439,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); - tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; + tokens[0] = n_processed == 0 && llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } @@ -1424,9 +1454,10 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); - const int32_t n_vocab = llama_n_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); - llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; + llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { llama_decode(ctx, llama_batch_get_one(&token, 1)); @@ -1537,7 +1568,7 @@ int main(int argc, char ** argv) { prev_inst = &inst; } - llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams()); + llama_context * ctx = llama_init_from_model(lmodel, inst.to_llama_cparams()); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str()); llama_model_free(lmodel); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 66ec2aeeba..2a73983a98 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -87,7 +87,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi auto path_to_model = env->GetStringUTFChars(filename, 0); LOGi("Loading model from %s", path_to_model); - auto model = llama_load_model_from_file(path_to_model, model_params); + auto model = llama_model_load_from_file(path_to_model, model_params); env->ReleaseStringUTFChars(filename, path_to_model); if (!model) { @@ -102,7 +102,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { - llama_free_model(reinterpret_cast(model)); + llama_model_free(reinterpret_cast(model)); } extern "C" @@ -347,6 +347,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( jlong context_pointer, jlong batch_pointer, jstring jtext, + jboolean format_chat, jint n_len ) { @@ -356,7 +357,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( const auto context = reinterpret_cast(context_pointer); const auto batch = reinterpret_cast(batch_pointer); - const auto tokens_list = common_tokenize(context, text, 1); + bool parse_special = (format_chat == JNI_TRUE); + const auto tokens_list = common_tokenize(context, text, true, parse_special); auto n_ctx = llama_n_ctx(context); auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); @@ -368,7 +370,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( } for (auto id : tokens_list) { - LOGi("%s", common_token_to_piece(context, id).c_str()); + LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id); } common_batch_clear(*batch); @@ -405,6 +407,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( const auto batch = reinterpret_cast(batch_pointer); const auto sampler = reinterpret_cast(sampler_pointer); const auto model = llama_get_model(context); + const auto vocab = llama_model_get_vocab(model); if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); @@ -414,7 +417,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( const auto new_token_id = llama_sampler_sample(sampler, context, -1); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); - if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { return nullptr; } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt index cf520e4594..b964d93e37 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt @@ -65,6 +65,7 @@ class LLamaAndroid { context: Long, batch: Long, text: String, + formatChat: Boolean, nLen: Int ): Int @@ -115,10 +116,10 @@ class LLamaAndroid { } } - fun send(message: String): Flow = flow { + fun send(message: String, formatChat: Boolean = false): Flow = flow { when (val state = threadLocalState.get()) { is State.Loaded -> { - val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) + val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen)) while (ncur.value <= nlen) { val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur) if (str == null) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 998c673d5d..477c3e6f2e 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -52,8 +52,8 @@ actor LlamaContext { deinit { llama_sampler_free(sampling) llama_batch_free(batch) + llama_model_free(model) llama_free(context) - llama_free_model(model) llama_backend_free() } @@ -65,7 +65,7 @@ actor LlamaContext { model_params.n_gpu_layers = 0 print("Running on simulator, force use n_gpu_layers = 0") #endif - let model = llama_load_model_from_file(path, model_params) + let model = llama_model_load_from_file(path, model_params) guard let model else { print("Could not load model at \(path)") throw LlamaError.couldNotInitializeContext @@ -151,7 +151,7 @@ actor LlamaContext { new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) - if llama_token_is_eog(model, new_token_id) || n_cur == n_len { + if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len { print("\n") is_done = true let new_token_str = String(cString: temporary_invalid_cchars + [0]) diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 27215a42e8..40aa0876f2 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -47,8 +47,12 @@ static const char * sample(struct common_sampler * smpl, int * n_past) { const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); common_sampler_accept(smpl, id, true); + + const llama_model * model = llama_get_model(ctx_llama); + const llama_vocab * vocab = llama_model_get_vocab(model); + static std::string ret; - if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { + if (llama_vocab_is_eog(vocab, id)) { ret = ""; } else { ret = common_token_to_piece(ctx_llama, id); @@ -239,11 +243,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); - llama_context_params ctx_params = common_context_params_to_llama(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); if (ctx_llama == NULL) { LOG_ERR("%s: failed to create the llama_context\n" , __func__); diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 16f30c56ce..c598caf3dd 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -384,7 +384,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); auto n_image_embd = clip_n_mmproj_embd(ctx_clip); if (n_image_embd != n_llama_embd) { LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); @@ -456,7 +456,7 @@ struct llava_embd_batch { }; bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { - int n_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { int n_eval = image_embed->n_image_pos - i; diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 2342bdd095..38c44e130e 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -54,7 +54,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_m ctx_params.n_ctx = params->n_ctx; } - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); if (ctx_llama == NULL) { LOG_ERR("%s: failed to create the llama_context\n" , __func__); @@ -167,8 +167,12 @@ static const char * sample(struct common_sampler * smpl, int * n_past) { const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); common_sampler_accept(smpl, id, true); + + const llama_model * model = llama_get_model(ctx_llama); + const llama_vocab * vocab = llama_model_get_vocab(model); + static std::string ret; - if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { + if (llama_vocab_is_eog(vocab, id)) { ret = ""; } else { ret = common_token_to_piece(ctx_llama, id); diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index f3e5d66e2c..132a7da543 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -27,7 +27,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) { - int n_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); const int patch_size = 14 * 2; const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0); const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0); @@ -132,8 +132,12 @@ static const char * sample(struct common_sampler * smpl, int * n_past, int * st_pos_id) { const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); common_sampler_accept(smpl, id, true); + + const llama_model * model = llama_get_model(ctx_llama); + const llama_vocab * vocab = llama_model_get_vocab(model); + static std::string ret; - if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { + if (llama_vocab_is_eog(vocab, id)) { ret = ""; } else { ret = common_token_to_piece(ctx_llama, id); @@ -328,11 +332,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); - llama_context_params ctx_params = common_context_params_to_llama(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); if (ctx_llama == NULL) { LOG_ERR("%s: failed to create the llama_context\n" , __func__); @@ -481,7 +484,7 @@ static void debug_test_mrope_2d() { } static void debug_dump_img_embed(struct llava_context * ctx_llava) { - int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); int ne = n_embd * 4; float vals[56 * 56 * 3]; // float embd[ne]; diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index e016618e35..2f0898e628 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -61,6 +61,8 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + const llama_vocab * vocab = llama_model_get_vocab(model); + // Tokenize the prompt std::vector inp; std::vector all; @@ -147,7 +149,7 @@ int main(int argc, char ** argv) { } // here we keep adding new n-grams as we go - ngram_container ngrams_observed(llama_n_vocab(model), N, G); + ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G); // debug struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); @@ -297,7 +299,7 @@ int main(int argc, char ** argv) { } fflush(stdout); - if (llama_token_is_eog(model, id)) { + if (llama_vocab_is_eog(vocab, id)) { has_eos = true; } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 0d68b80b9a..dbd0444ec8 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -36,6 +36,8 @@ int main(int argc, char ** argv){ llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + const llama_vocab * vocab = llama_model_get_vocab(model); + // tokenize the prompt std::vector inp; inp = common_tokenize(ctx, params.prompt, true, true); @@ -136,7 +138,7 @@ int main(int argc, char ** argv){ LOG("%s", token_str.c_str()); } - if (llama_token_is_eog(model, id)) { + if (llama_vocab_is_eog(vocab, id)) { has_eos = true; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index aaee47e325..39666a0e8a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -5,7 +5,6 @@ #include "sampling.h" #include "llama.h" -#include #include #include #include @@ -31,6 +30,8 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant"; + static llama_context ** g_ctx; static llama_model ** g_model; static common_sampler ** g_smpl; @@ -163,6 +164,8 @@ int main(int argc, char ** argv) { return 1; } + const llama_vocab * vocab = llama_model_get_vocab(model); + LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); @@ -196,15 +199,31 @@ int main(int argc, char ** argv) { llama_attach_threadpool(ctx, threadpool, threadpool_batch); - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); if (n_ctx > n_ctx_train) { LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); } + // auto enable conversation mode if chat template is available + const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty(); + if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { + if (has_chat_template) { + LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); + params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED; + } else { + params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED; + } + } + + // in case user force-activate conversation mode (via -cnv) without proper chat template, we show a warning + if (params.conversation_mode && !has_chat_template) { + LOG_WRN("%s: chat template is not available or is not supported. This may cause the model to output suboptimal responses\n", __func__); + } + // print chat template example in conversation mode - if (params.conversation) { + if (params.conversation_mode) { if (params.enable_chat_template) { LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str()); } else { @@ -241,9 +260,9 @@ int main(int argc, char ** argv) { } } - const bool add_bos = llama_add_bos_token(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); if (!llama_model_has_encoder(model)) { - GGML_ASSERT(!llama_add_eos_token(model)); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); } LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos); @@ -251,8 +270,10 @@ int main(int argc, char ** argv) { std::vector embd_inp; { - auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty()) - ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode + auto prompt = (params.conversation_mode && params.enable_chat_template) + // format the system prompt in conversation mode (fallback to default if empty) + ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) + // otherwise use the prompt as is : params.prompt; if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { LOG_DBG("tokenize the prompt\n"); @@ -269,7 +290,7 @@ int main(int argc, char ** argv) { // Should not run without any tokens if (embd_inp.empty()) { if (add_bos) { - embd_inp.push_back(llama_token_bos(model)); + embd_inp.push_back(llama_vocab_bos(vocab)); LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); } else { LOG_ERR("input is empty\n"); @@ -326,7 +347,7 @@ int main(int argc, char ** argv) { params.n_keep += add_bos; // always keep the BOS token } - if (params.conversation) { + if (params.conversation_mode) { params.interactive_first = true; } @@ -450,7 +471,11 @@ int main(int argc, char ** argv) { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) LOG_INF( " - Press Ctrl+C to interject at any time.\n"); #endif - LOG_INF( "%s\n", control_message); + LOG_INF( "%s", control_message); + if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) { + LOG_INF( " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n"); + } + LOG_INF("\n"); is_interacting = params.interactive_first; } @@ -495,7 +520,7 @@ int main(int argc, char ** argv) { llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == LLAMA_TOKEN_NULL) { - decoder_start_token_id = llama_token_bos(model); + decoder_start_token_id = llama_vocab_bos(vocab); } embd_inp.clear(); @@ -742,7 +767,7 @@ int main(int argc, char ** argv) { } // deal with end of generation tokens in interactive mode - if (llama_token_is_eog(model, common_sampler_last(smpl))) { + if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { LOG_DBG("found an EOG token\n"); if (params.interactive) { @@ -762,7 +787,7 @@ int main(int argc, char ** argv) { } // if current token is not EOG, we add it to current assistant message - if (params.conversation) { + if (params.conversation_mode) { const auto id = common_sampler_last(smpl); assistant_ss << common_token_to_piece(ctx, id, false); } @@ -770,17 +795,17 @@ int main(int argc, char ** argv) { if (n_past > 0 && is_interacting) { LOG_DBG("waiting for user input\n"); - if (params.conversation) { + if (params.conversation_mode) { LOG("\n> "); } if (params.input_prefix_bos) { LOG_DBG("adding input prefix BOS token\n"); - embd_inp.push_back(llama_token_bos(model)); + embd_inp.push_back(llama_vocab_bos(vocab)); } std::string buffer; - if (!params.input_prefix.empty() && !params.conversation) { + if (!params.input_prefix.empty() && !params.conversation_mode) { LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str()); LOG("%s", params.input_prefix.c_str()); } @@ -804,7 +829,7 @@ int main(int argc, char ** argv) { // Entering a empty line lets the user pass control back if (buffer.length() > 1) { // append input suffix if any - if (!params.input_suffix.empty() && !params.conversation) { + if (!params.input_suffix.empty() && !params.conversation_mode) { LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str()); LOG("%s", params.input_suffix.c_str()); } @@ -817,7 +842,7 @@ int main(int argc, char ** argv) { string_process_escapes(buffer); } - bool format_chat = params.conversation && params.enable_chat_template; + bool format_chat = params.conversation_mode && params.enable_chat_template; std::string user_inp = format_chat ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer)) : std::move(buffer); @@ -830,8 +855,8 @@ int main(int argc, char ** argv) { // if user stop generation mid-way, we must add EOT to finish model's last response if (need_insert_eot && format_chat) { - llama_token eot = llama_token_eot(model); - embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(model) : eot); + llama_token eot = llama_vocab_eot(vocab); + embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot); need_insert_eot = false; } @@ -866,7 +891,7 @@ int main(int argc, char ** argv) { } // end of generation - if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.interactive)) { + if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) { LOG(" [end of text]\n"); break; } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index d48f519757..7ef43d5e12 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -135,6 +135,8 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + const llama_vocab * vocab = llama_model_get_vocab(model); + // load the prompts from an external file if there are any if (params.prompt.empty()) { LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); @@ -358,7 +360,7 @@ int main(int argc, char ** argv) { // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); if (client.n_decoded > 2 && - (llama_token_is_eog(model, id) || + (llama_vocab_is_eog(vocab, id) || (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || client.response.find("User:") != std::string::npos || client.response.find('\n') != std::string::npos)) { diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index ea91f376cd..5953928d47 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -70,15 +70,17 @@ int main(int argc, char ** argv) { return 1; } + const llama_vocab * vocab = llama_model_get_vocab(model); + // initialize the context llama_context_params ctx_params = common_context_params_to_llama(params); - ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; + ctx_params.n_ctx = llama_model_n_ctx_train(model)*n_grp + n_keep; GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { LOG_ERR("%s: failed to create the llama_context\n" , __func__); return 1; @@ -223,7 +225,7 @@ int main(int argc, char ** argv) { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); // is it an end of generation? - if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { LOG("\n"); break; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 6bdc57f8e0..9bf6c57433 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -296,8 +296,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); LOG_INF("%s: tokenizing the input ..\n", __func__); @@ -338,7 +341,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_vocab = llama_vocab_n_tokens(vocab); int count = 0; double nll = 0.0; @@ -382,7 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); + tokens[batch_start] = llama_vocab_bos(vocab); } const auto * batch_logits = llama_get_logits(ctx); @@ -444,8 +447,11 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); std::ofstream logits_stream; if (!params.logits_file.empty()) { @@ -485,7 +491,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_vocab = llama_vocab_n_tokens(vocab); int count = 0; double nll = 0.0; @@ -557,7 +563,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); + tokens[seq_start] = llama_vocab_bos(vocab); } for (int k = 0; k < batch_size; ++k) { @@ -732,6 +738,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto } static void hellaswag_score(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + // Calculates hellaswag score (acc_norm) from prompt // // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl @@ -765,7 +774,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { size_t hs_task_count = prompt_lines.size()/6; LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count); - const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; + const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM; LOG_INF("================================= is_spm = %d\n", is_spm); // The tasks should be randomized so the score stabilizes quickly. @@ -848,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); @@ -1072,6 +1081,8 @@ static std::vector load_winogrande_from_csv(const std::string * */ static void winogrande_score(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); constexpr int k_min_trailing_ctx = 3; @@ -1130,7 +1141,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 128; const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); @@ -1374,6 +1385,8 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic // https://huggingface.co/datasets/truthful_qa // static void multiple_choice_score(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); std::istringstream strstream(params.prompt); uint32_t n_task; @@ -1482,7 +1495,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); @@ -1655,6 +1668,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par } static void kl_divergence(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + if (params.logits_file.empty()) { LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__); return; @@ -1688,8 +1704,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); return; } - if (n_vocab != llama_n_vocab(llama_get_model(ctx))) { - LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx))); + if (n_vocab != llama_vocab_n_tokens(vocab)) { + LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_tokens(vocab)); } std::vector tokens(size_t(n_ctx) * n_chunk); @@ -1701,8 +1717,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { const int n_batch = params.n_batch; const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int nv = 2*((n_vocab + 1)/2) + 4; - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); + const bool add_bos = llama_vocab_get_add_bos(vocab); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); @@ -1761,7 +1777,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); + tokens[batch_start] = llama_vocab_bos(vocab); } common_batch_clear(batch); @@ -1995,7 +2011,7 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 9bfbb88623..bd2f734670 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -319,7 +319,7 @@ int main(int argc, char ** argv) { auto cparams = llama_context_default_params(); cparams.n_ctx = 256; - ctx = llama_new_context_with_model(model, cparams); + ctx = llama_init_from_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); diff --git a/examples/quantize/tests.sh b/examples/quantize/tests.sh index 24bc970e86..70f7610f98 100644 --- a/examples/quantize/tests.sh +++ b/examples/quantize/tests.sh @@ -47,7 +47,7 @@ echo PASS echo # 3a. Test the requanted model is loading properly -$MAIN --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32 echo PASS echo @@ -57,7 +57,7 @@ echo PASS echo # 4b. Test the requanted model is loading properly -$MAIN --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32 echo PASS echo diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index f534b5effd..2439022a22 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -159,7 +159,9 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); @@ -192,8 +194,8 @@ int main(int argc, char ** argv) { return 1; } // add eos if not present - if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) { - inp.push_back(llama_token_eos(model)); + if (llama_vocab_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_vocab_eos(vocab))) { + inp.push_back(llama_vocab_eos(vocab)); } chunk.tokens = inp; } @@ -215,7 +217,7 @@ int main(int argc, char ** argv) { struct llama_batch batch = llama_batch_init(n_batch, 0, 1); // allocate output - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embeddings(n_chunks * n_embd, 0); float * emb = embeddings.data(); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 61420e441e..0ad8bb15b2 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -29,7 +29,7 @@ #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) [[noreturn]] static void sigint_handler(int) { - printf("\n"); + printf("\n\033[0m"); exit(0); // not ideal, but it's the only way to guarantee exit in all cases } #endif @@ -685,7 +685,7 @@ class LlamaData { // Initializes the context with the specified parameters llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { - llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params)); + llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params)); if (!context) { printe("%s: error: failed to create the llama_context\n", __func__); } @@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData & // Function to apply the chat template and resize `formatted` if needed static int apply_chat_template(LlamaData & llama_data, const bool append) { int result = llama_chat_apply_template( - llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append, + llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append, append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); if (append && result > static_cast(llama_data.fmtted.size())) { llama_data.fmtted.resize(result); - result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), + result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append, llama_data.fmtted.data(), llama_data.fmtted.size()); } @@ -726,11 +726,11 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) { } // Function to tokenize the prompt -static int tokenize_prompt(const llama_model_ptr & model, const std::string & prompt, +static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, std::vector & prompt_tokens) { - const int n_prompt_tokens = -llama_tokenize(model.get(), prompt.c_str(), prompt.size(), NULL, 0, true, true); + const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); prompt_tokens.resize(n_prompt_tokens); - if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { printe("failed to tokenize the prompt\n"); return -1; @@ -753,9 +753,9 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch & } // convert the token to a string -static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) { +static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) { char buf[256]; - int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); + int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true); if (n < 0) { printe("failed to convert token to piece\n"); return 1; @@ -773,8 +773,10 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st // helper function to evaluate a prompt and generate a response static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { + const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); + std::vector tokens; - if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) { + if (tokenize_prompt(vocab, prompt, tokens) < 0) { return 1; } @@ -790,12 +792,12 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str // sample the next token, check is it an end of generation? new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); - if (llama_token_is_eog(llama_data.model.get(), new_token_id)) { + if (llama_vocab_is_eog(vocab, new_token_id)) { break; } std::string piece; - if (convert_token_to_string(llama_data.model, new_token_id, piece)) { + if (convert_token_to_string(vocab, new_token_id, piece)) { return 1; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index cd03661cf6..cf7cbd8159 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -97,7 +97,7 @@ int main(int argc, char ** argv) { printf("\n\n"); // make new context - llama_context * ctx2 = llama_new_context_with_model(model, common_context_params_to_llama(params)); + llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params)); llama_sampler * smpl2 = llama_sampler_chain_init(sparams); @@ -154,7 +154,7 @@ int main(int argc, char ** argv) { } // make new context - llama_context * ctx3 = llama_new_context_with_model(model, common_context_params_to_llama(params)); + llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params)); llama_sampler * smpl3 = llama_sampler_chain_init(sparams); diff --git a/examples/server/httplib.h b/examples/server/httplib.h index f360bd93ea..c2f12dd2ac 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -8,7 +8,7 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.15.3" +#define CPPHTTPLIB_VERSION "0.18.5" /* * Configuration @@ -18,8 +18,12 @@ #define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 #endif +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000 +#endif + #ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT -#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100 #endif #ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND @@ -30,20 +34,36 @@ #define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 #endif -#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND -#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5 #endif -#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND -#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0 #endif -#ifndef CPPHTTPLIB_WRITE_TIMEOUT_SECOND -#define CPPHTTPLIB_WRITE_TIMEOUT_SECOND 5 +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5 #endif -#ifndef CPPHTTPLIB_WRITE_TIMEOUT_USECOND -#define CPPHTTPLIB_WRITE_TIMEOUT_USECOND 0 +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 #endif #ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND @@ -90,8 +110,12 @@ #define CPPHTTPLIB_TCP_NODELAY false #endif +#ifndef CPPHTTPLIB_IPV6_V6ONLY +#define CPPHTTPLIB_IPV6_V6ONLY false +#endif + #ifndef CPPHTTPLIB_RECV_BUFSIZ -#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u) #endif #ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ @@ -145,11 +169,11 @@ using ssize_t = long; #endif // _MSC_VER #ifndef S_ISREG -#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG) #endif // S_ISREG #ifndef S_ISDIR -#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR) #endif // S_ISDIR #ifndef NOMINMAX @@ -269,7 +293,12 @@ using socket_t = int; #include #include -#if OPENSSL_VERSION_NUMBER < 0x30000000L +#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER) +#if OPENSSL_VERSION_NUMBER < 0x1010107f +#error Please use OpenSSL or a current version of BoringSSL +#endif +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#elif OPENSSL_VERSION_NUMBER < 0x30000000L #error Sorry, OpenSSL versions prior to 3.0.0 are not supported #endif @@ -312,16 +341,63 @@ make_unique(std::size_t n) { return std::unique_ptr(new RT[n]); } -struct ci { - bool operator()(const std::string &s1, const std::string &s2) const { - return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), - s2.end(), - [](unsigned char c1, unsigned char c2) { - return ::tolower(c1) < ::tolower(c2); - }); +namespace case_ignore { + +inline unsigned char to_lower(int c) { + const static unsigned char table[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, + 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, + 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, + 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, + 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, + 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 224, 225, 226, + 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, 224, + 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, + 255, + }; + return table[(unsigned char)(char)c]; +} + +inline bool equal(const std::string &a, const std::string &b) { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { + return to_lower(ca) == to_lower(cb); + }); +} + +struct equal_to { + bool operator()(const std::string &a, const std::string &b) const { + return equal(a, b); } }; +struct hash { + size_t operator()(const std::string &key) const { + return hash_core(key.data(), key.size(), 0); + } + + size_t hash_core(const char *s, size_t l, size_t h) const { + return (l == 0) ? h + : hash_core(s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no + // overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(to_lower(*s))); + } +}; + +} // namespace case_ignore + // This is based on // "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". @@ -427,7 +503,9 @@ enum StatusCode { NetworkAuthenticationRequired_511 = 511, }; -using Headers = std::multimap; +using Headers = + std::unordered_multimap; using Params = std::multimap; using Match = std::smatch; @@ -534,6 +612,7 @@ using Ranges = std::vector; struct Request { std::string method; std::string path; + Params params; Headers headers; std::string body; @@ -545,11 +624,11 @@ struct Request { // for server std::string version; std::string target; - Params params; MultipartFormDataMap files; Ranges ranges; Match matches; std::unordered_map path_params; + std::function is_connection_closed = []() { return true; }; // for client ResponseHandler response_handler; @@ -560,8 +639,10 @@ struct Request { #endif bool has_header(const std::string &key) const; - std::string get_header_value(const std::string &key, size_t id = 0) const; - uint64_t get_header_value_u64(const std::string &key, size_t id = 0) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, + size_t id = 0) const; size_t get_header_value_count(const std::string &key) const; void set_header(const std::string &key, const std::string &val); @@ -592,8 +673,10 @@ struct Response { std::string location; // Redirect location bool has_header(const std::string &key) const; - std::string get_header_value(const std::string &key, size_t id = 0) const; - uint64_t get_header_value_u64(const std::string &key, size_t id = 0) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, + size_t id = 0) const; size_t get_header_value_count(const std::string &key) const; void set_header(const std::string &key, const std::string &val); @@ -614,6 +697,10 @@ struct Response { const std::string &content_type, ContentProviderWithoutLength provider, ContentProviderResourceReleaser resource_releaser = nullptr); + void set_file_content(const std::string &path, + const std::string &content_type); + void set_file_content(const std::string &path); + Response() = default; Response(const Response &) = default; Response &operator=(const Response &) = default; @@ -631,6 +718,8 @@ struct Response { ContentProviderResourceReleaser content_provider_resource_releaser_; bool is_chunked_content_provider_ = false; bool content_provider_success_ = false; + std::string file_content_path_; + std::string file_content_content_type_; }; class Stream { @@ -646,8 +735,6 @@ public: virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; virtual socket_t socket() const = 0; - template - ssize_t write_format(const char *fmt, const Args &...args); ssize_t write(const char *ptr); ssize_t write(const std::string &s); }; @@ -719,13 +806,18 @@ private: if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } - fn = std::move(pool_.jobs_.front()); + fn = pool_.jobs_.front(); pool_.jobs_.pop_front(); } assert(true == static_cast(fn)); fn(); } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif } ThreadPool &pool_; @@ -787,7 +879,6 @@ public: bool match(Request &request) const override; private: - static constexpr char marker = ':'; // Treat segment separators as the end of path parameter capture // Does not need to handle query parameters as they are parsed before path // matching @@ -871,8 +962,13 @@ public: Server &set_default_file_mimetype(const std::string &mime); Server &set_file_request_handler(Handler handler); - Server &set_error_handler(HandlerWithResponse handler); - Server &set_error_handler(Handler handler); + template + Server &set_error_handler(ErrorHandlerFunc &&handler) { + return set_error_handler_core( + std::forward(handler), + std::is_convertible{}); + } + Server &set_exception_handler(ExceptionHandler handler); Server &set_pre_routing_handler(HandlerWithResponse handler); Server &set_post_routing_handler(Handler handler); @@ -882,6 +978,7 @@ public: Server &set_address_family(int family); Server &set_tcp_nodelay(bool on); + Server &set_ipv6_v6only(bool on); Server &set_socket_options(SocketOptions socket_options); Server &set_default_headers(Headers headers); @@ -914,21 +1011,24 @@ public: bool is_running() const; void wait_until_ready() const; void stop(); + void decommission(); std::function new_task_queue; protected: - bool process_request(Stream &strm, bool close_connection, + bool process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, bool &connection_closed, const std::function &setup_request); std::atomic svr_sock_{INVALID_SOCKET}; size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND; time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; @@ -943,6 +1043,9 @@ private: static std::unique_ptr make_matcher(const std::string &pattern); + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); + Server &set_error_handler_core(Handler handler, std::false_type); + socket_t create_server_socket(const std::string &host, int port, int socket_flags, SocketOptions socket_options) const; @@ -985,7 +1088,7 @@ private: virtual bool process_and_close_socket(socket_t sock); std::atomic is_running_{false}; - std::atomic done_{false}; + std::atomic is_decommisioned{false}; struct MountPointEntry { std::string mount_point; @@ -1018,6 +1121,7 @@ private: int address_family_ = AF_UNSPEC; bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; SocketOptions socket_options_ = default_socket_options; Headers default_headers_; @@ -1037,6 +1141,7 @@ enum class Error { SSLConnection, SSLLoadingCerts, SSLServerVerification, + SSLServerHostnameVerification, UnsupportedMultipartBoundaryChars, Compression, ConnectionTimeout, @@ -1074,9 +1179,10 @@ public: // Request Headers bool has_request_header(const std::string &key) const; std::string get_request_header_value(const std::string &key, + const char *def = "", size_t id = 0) const; uint64_t get_request_header_value_u64(const std::string &key, - size_t id = 0) const; + uint64_t def = 0, size_t id = 0) const; size_t get_request_header_value_count(const std::string &key) const; private: @@ -1140,10 +1246,18 @@ public: const std::string &content_type); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); @@ -1159,6 +1273,8 @@ public: Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); Result Post(const std::string &path, const MultipartFormDataItems &items); Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); @@ -1173,10 +1289,18 @@ public: const std::string &content_type); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); Result Put(const std::string &path, @@ -1191,6 +1315,8 @@ public: Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); Result Put(const std::string &path, const MultipartFormDataItems &items); Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); @@ -1203,13 +1329,23 @@ public: Result Patch(const std::string &path); Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); @@ -1227,13 +1363,24 @@ public: Result Delete(const std::string &path, const Headers &headers); Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Options(const std::string &path); Result Options(const std::string &path, const Headers &headers); @@ -1258,6 +1405,7 @@ public: void set_address_family(int family); void set_tcp_nodelay(bool on); + void set_ipv6_v6only(bool on); void set_socket_options(SocketOptions socket_options); void set_connection_timeout(time_t sec, time_t usec = 0); @@ -1309,6 +1457,8 @@ public: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); #endif void set_logger(Logger logger); @@ -1375,10 +1525,10 @@ protected: time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; std::string basic_auth_username_; std::string basic_auth_password_; @@ -1395,6 +1545,7 @@ protected: int address_family_ = AF_UNSPEC; bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; SocketOptions socket_options_ = nullptr; bool compress_ = false; @@ -1422,6 +1573,8 @@ protected: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::function server_certificate_verifier_; #endif Logger logger_; @@ -1430,6 +1583,9 @@ private: bool send_(Request &req, Response &res, Error &error); Result send_(Request &&req); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_peer_could_be_closed(SSL *ssl) const; +#endif socket_t create_client_socket(Error &error) const; bool read_response_line(Stream &strm, const Request &req, Response &res) const; @@ -1448,7 +1604,7 @@ private: const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type); + const std::string &content_type, Progress progress); ContentProviderWithoutLength get_multipart_content_provider( const std::string &boundary, const MultipartFormDataItems &items, const MultipartFormDataProviderItems &provider_items) const; @@ -1477,6 +1633,7 @@ public: const std::string &client_key_path); Client(Client &&) = default; + Client &operator=(Client &&) = default; ~Client(); @@ -1523,10 +1680,18 @@ public: const std::string &content_type); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); @@ -1542,6 +1707,8 @@ public: Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); Result Post(const std::string &path, const MultipartFormDataItems &items); Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); @@ -1556,10 +1723,18 @@ public: const std::string &content_type); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); Result Put(const std::string &path, @@ -1574,6 +1749,8 @@ public: Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); Result Put(const std::string &path, const MultipartFormDataItems &items); Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); @@ -1586,13 +1763,23 @@ public: Result Patch(const std::string &path); Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); @@ -1610,13 +1797,24 @@ public: Result Delete(const std::string &path, const Headers &headers); Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Options(const std::string &path); Result Options(const std::string &path, const Headers &headers); @@ -1685,6 +1883,8 @@ public: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); #endif void set_logger(Logger logger); @@ -1730,6 +1930,9 @@ public: SSL_CTX *ssl_context() const; + void update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + private: bool process_and_close_socket(socket_t sock) override; @@ -1810,68 +2013,58 @@ inline void duration_to_sec_and_usec(const T &duration, U callback) { callback(static_cast(sec), static_cast(usec)); } +inline bool is_numeric(const std::string &str) { + return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); +} + inline uint64_t get_header_value_u64(const Headers &headers, - const std::string &key, size_t id, - uint64_t def) { + const std::string &key, uint64_t def, + size_t id, bool &is_invalid_value) { + is_invalid_value = false; auto rng = headers.equal_range(key); auto it = rng.first; std::advance(it, static_cast(id)); if (it != rng.second) { - return std::strtoull(it->second.data(), nullptr, 10); + if (is_numeric(it->second)) { + return std::strtoull(it->second.data(), nullptr, 10); + } else { + is_invalid_value = true; + } } return def; } +inline uint64_t get_header_value_u64(const Headers &headers, + const std::string &key, uint64_t def, + size_t id) { + bool dummy = false; + return get_header_value_u64(headers, key, def, id, dummy); +} + } // namespace detail inline uint64_t Request::get_header_value_u64(const std::string &key, - size_t id) const { - return detail::get_header_value_u64(headers, key, id, 0); + uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); } inline uint64_t Response::get_header_value_u64(const std::string &key, - size_t id) const { - return detail::get_header_value_u64(headers, key, id, 0); -} - -template -inline ssize_t Stream::write_format(const char *fmt, const Args &...args) { - const auto bufsiz = 2048; - std::array buf{}; - - auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); - if (sn <= 0) { return sn; } - - auto n = static_cast(sn); - - if (n >= buf.size() - 1) { - std::vector glowable_buf(buf.size()); - - while (n >= glowable_buf.size() - 1) { - glowable_buf.resize(glowable_buf.size() * 2); - n = static_cast( - snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); - } - return write(&glowable_buf[0], n); - } else { - return write(buf.data(), n); - } + uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); } inline void default_socket_options(socket_t sock) { - int yes = 1; + int opt = 1; #ifdef _WIN32 setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, - reinterpret_cast(&yes), sizeof(yes)); - setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); #else #ifdef SO_REUSEPORT setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); #else setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); #endif #endif } @@ -1997,6 +2190,8 @@ inline std::string to_string(const Error error) { case Error::SSLConnection: return "SSL connection failed"; case Error::SSLLoadingCerts: return "SSL certificate loading failed"; case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; case Error::UnsupportedMultipartBoundaryChars: return "Unsupported HTTP multipart boundary characters"; case Error::Compression: return "Compression failed"; @@ -2016,8 +2211,9 @@ inline std::ostream &operator<<(std::ostream &os, const Error &obj) { } inline uint64_t Result::get_request_header_value_u64(const std::string &key, + uint64_t def, size_t id) const { - return detail::get_header_value_u64(request_headers_, key, id, 0); + return detail::get_header_value_u64(request_headers_, key, def, id); } template @@ -2080,6 +2276,36 @@ make_basic_authentication_header(const std::string &username, namespace detail { +#if defined(_WIN32) +inline std::wstring u8string_to_wstring(const char *s) { + std::wstring ws; + auto len = static_cast(strlen(s)); + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); + if (wlen > 0) { + ws.resize(wlen); + wlen = ::MultiByteToWideChar( + CP_UTF8, 0, s, len, + const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { ws.clear(); } + } + return ws; +} +#endif + +struct FileStat { + FileStat(const std::string &path); + bool is_file() const; + bool is_dir() const; + +private: +#if defined(_WIN32) + struct _stat st_; +#else + struct stat st_; +#endif + int ret_ = -1; +}; + std::string encode_query_param(const std::string &value); std::string decode_url(const std::string &s, bool convert_plus_to_space); @@ -2088,6 +2314,16 @@ void read_file(const std::string &path, std::string &out); std::string trim_copy(const std::string &s); +void divide( + const char *data, std::size_t size, char d, + std::function + fn); + +void divide( + const std::string &str, char d, + std::function + fn); + void split(const char *b, const char *e, char d, std::function fn); @@ -2099,18 +2335,23 @@ bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t write_timeout_usec, std::function callback); -socket_t create_client_socket( - const std::string &host, const std::string &ip, int port, - int address_family, bool tcp_nodelay, SocketOptions socket_options, - time_t connection_timeout_sec, time_t connection_timeout_usec, - time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, const std::string &intf, Error &error); +socket_t create_client_socket(const std::string &host, const std::string &ip, + int port, int address_family, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, + time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + const std::string &intf, Error &error); const char *get_header_value(const Headers &headers, const std::string &key, - size_t id = 0, const char *def = nullptr); + const char *def, size_t id); std::string params_to_query_str(const Params ¶ms); +void parse_query_text(const char *data, std::size_t size, Params ¶ms); + void parse_query_text(const std::string &s, Params ¶ms); bool parse_multipart_boundary(const std::string &content_type, @@ -2270,15 +2511,70 @@ public: private: #if defined(_WIN32) - HANDLE hFile_; - HANDLE hMapping_; + HANDLE hFile_ = NULL; + HANDLE hMapping_ = NULL; #else - int fd_; + int fd_ = -1; #endif - size_t size_; - void *addr_; + size_t size_ = 0; + void *addr_ = nullptr; + bool is_open_empty_file = false; }; +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +inline bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +inline bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +inline bool is_field_name(const std::string &s) { return is_token(s); } + +inline bool is_vchar(char c) { return c >= 33 && c <= 126; } + +inline bool is_obs_text(char c) { return 128 <= static_cast(c); } + +inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +inline bool is_field_content(const std::string &s) { + if (s.empty()) { return false; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +inline bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + } // namespace detail // ---------------------------------------------------------------------------- @@ -2392,20 +2688,6 @@ inline std::string base64_encode(const std::string &in) { return out; } -inline bool is_file(const std::string &path) { -#ifdef _WIN32 - return _access_s(path.c_str(), 0) == 0; -#else - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); -#endif -} - -inline bool is_dir(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); -} - inline bool is_valid_path(const std::string &path) { size_t level = 0; size_t i = 0; @@ -2448,6 +2730,21 @@ inline bool is_valid_path(const std::string &path) { return true; } +inline FileStat::FileStat(const std::string &path) { +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path.c_str()); + ret_ = _wstat(wpath.c_str(), &st_); +#else + ret_ = stat(path.c_str(), &st_); +#endif +} +inline bool FileStat::is_file() const { + return ret_ >= 0 && S_ISREG(st_.st_mode); +} +inline bool FileStat::is_dir() const { + return ret_ >= 0 && S_ISDIR(st_.st_mode); +} + inline std::string encode_query_param(const std::string &value) { std::ostringstream escaped; escaped.fill('0'); @@ -2579,6 +2876,27 @@ inline std::string trim_double_quotes_copy(const std::string &s) { return s; } +inline void +divide(const char *data, std::size_t size, char d, + std::function + fn) { + const auto it = std::find(data, data + size, d); + const auto found = static_cast(it != data + size); + const auto lhs_data = data; + const auto lhs_size = static_cast(it - data); + const auto rhs_data = it + found; + const auto rhs_size = size - lhs_size - found; + + fn(lhs_data, lhs_size, rhs_data, rhs_size); +} + +inline void +divide(const std::string &str, char d, + std::function + fn) { + divide(str.data(), str.size(), d, std::move(fn)); +} + inline void split(const char *b, const char *e, char d, std::function fn) { return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); @@ -2636,6 +2954,10 @@ inline bool stream_line_reader::getline() { fixed_buffer_used_size_ = 0; glowable_buffer_.clear(); +#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + char prev_byte = 0; +#endif + for (size_t i = 0;; i++) { char byte; auto n = strm_.read(&byte, 1); @@ -2652,7 +2974,12 @@ inline bool stream_line_reader::getline() { append(byte); +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR if (byte == '\n') { break; } +#else + if (prev_byte == '\r' && byte == '\n') { break; } + prev_byte = byte; +#endif } return true; @@ -2671,16 +2998,7 @@ inline void stream_line_reader::append(char c) { } } -inline mmap::mmap(const char *path) -#if defined(_WIN32) - : hFile_(NULL), hMapping_(NULL) -#else - : fd_(-1) -#endif - , - size_(0), addr_(nullptr) { - open(path); -} +inline mmap::mmap(const char *path) { open(path); } inline mmap::~mmap() { close(); } @@ -2688,29 +3006,60 @@ inline bool mmap::open(const char *path) { close(); #if defined(_WIN32) - std::wstring wpath; - for (size_t i = 0; i < strlen(path); i++) { - wpath += path[i]; - } + auto wpath = u8string_to_wstring(path); + if (wpath.empty()) { return false; } +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL); +#else + hFile_ = ::CreateFileW(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, + OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); +#endif if (hFile_ == INVALID_HANDLE_VALUE) { return false; } LARGE_INTEGER size{}; if (!::GetFileSizeEx(hFile_, &size)) { return false; } + // If the following line doesn't compile due to QuadPart, update Windows SDK. + // See: + // https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721 + if (static_cast(size.QuadPart) > + (std::numeric_limits::max)()) { + // `size_t` might be 32-bits, on 32-bits Windows. + return false; + } size_ = static_cast(size.QuadPart); +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 hMapping_ = ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); +#else + hMapping_ = ::CreateFileMappingW(hFile_, NULL, PAGE_READONLY, 0, 0, NULL); +#endif + + // Special treatment for an empty file... + if (hMapping_ == NULL && size_ == 0) { + close(); + is_open_empty_file = true; + return true; + } if (hMapping_ == NULL) { close(); return false; } +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); +#else + addr_ = ::MapViewOfFile(hMapping_, FILE_MAP_READ, 0, 0, 0); +#endif + + if (addr_ == nullptr) { + close(); + return false; + } #else fd_ = ::open(path, O_RDONLY); if (fd_ == -1) { return false; } @@ -2723,22 +3072,26 @@ inline bool mmap::open(const char *path) { size_ = static_cast(sb.st_size); addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); -#endif - if (addr_ == nullptr) { + // Special treatment for an empty file... + if (addr_ == MAP_FAILED && size_ == 0) { close(); + is_open_empty_file = true; return false; } +#endif return true; } -inline bool mmap::is_open() const { return addr_ != nullptr; } +inline bool mmap::is_open() const { + return is_open_empty_file ? true : addr_ != nullptr; +} inline size_t mmap::size() const { return size_; } inline const char *mmap::data() const { - return static_cast(addr_); + return is_open_empty_file ? "" : static_cast(addr_); } inline void mmap::close() { @@ -2757,6 +3110,8 @@ inline void mmap::close() { ::CloseHandle(hFile_); hFile_ = INVALID_HANDLE_VALUE; } + + is_open_empty_file = false; #else if (addr_ != nullptr) { munmap(addr_, size_); @@ -2782,7 +3137,10 @@ template inline ssize_t handle_EINTR(T fn) { ssize_t res = 0; while (true) { res = fn(); - if (res < 0 && errno == EINTR) { continue; } + if (res < 0 && errno == EINTR) { + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } break; } return res; @@ -2991,23 +3349,37 @@ private: }; #endif -inline bool keep_alive(socket_t sock, time_t keep_alive_timeout_sec) { +inline bool keep_alive(const std::atomic &svr_sock, socket_t sock, + time_t keep_alive_timeout_sec) { using namespace std::chrono; - auto start = steady_clock::now(); + + const auto interval_usec = + CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND; + + // Avoid expensive `steady_clock::now()` call for the first time + if (select_read(sock, 0, interval_usec) > 0) { return true; } + + const auto start = steady_clock::now() - microseconds{interval_usec}; + const auto timeout = seconds{keep_alive_timeout_sec}; + while (true) { - auto val = select_read(sock, 0, 10000); + if (svr_sock == INVALID_SOCKET) { + break; // Server socket is closed + } + + auto val = select_read(sock, 0, interval_usec); if (val < 0) { - return false; + break; // Ssocket error } else if (val == 0) { - auto current = steady_clock::now(); - auto duration = duration_cast(current - start); - auto timeout = keep_alive_timeout_sec * 1000; - if (duration.count() > timeout) { return false; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + if (steady_clock::now() - start > timeout) { + break; // Timeout + } } else { - return true; + return true; // Ready for read } } + + return false; } template @@ -3018,8 +3390,7 @@ process_server_socket_core(const std::atomic &svr_sock, socket_t sock, assert(keep_alive_max_count > 0); auto ret = false; auto count = keep_alive_max_count; - while (svr_sock != INVALID_SOCKET && count > 0 && - keep_alive(sock, keep_alive_timeout_sec)) { + while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) { auto close_connection = count == 1; auto connection_closed = false; ret = callback(close_connection, connection_closed); @@ -3063,10 +3434,29 @@ inline int shutdown_socket(socket_t sock) { #endif } +inline std::string escape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '\0') { + auto ret = s; + ret[0] = '@'; + return ret; + } + return s; +} + +inline std::string +unescape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '@') { + auto ret = s; + ret[0] = '\0'; + return ret; + } + return s; +} + template socket_t create_socket(const std::string &host, const std::string &ip, int port, int address_family, int socket_flags, bool tcp_nodelay, - SocketOptions socket_options, + bool ipv6_v6only, SocketOptions socket_options, BindOrConnect bind_or_connect) { // Get address info const char *node = nullptr; @@ -3075,7 +3465,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, memset(&hints, 0, sizeof(struct addrinfo)); hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = 0; + hints.ai_protocol = IPPROTO_IP; if (!ip.empty()) { node = ip.c_str(); @@ -3093,20 +3483,32 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, const auto addrlen = host.length(); if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } +#ifdef SOCK_CLOEXEC + auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC, + hints.ai_protocol); +#else auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); +#endif + if (sock != INVALID_SOCKET) { sockaddr_un addr{}; addr.sun_family = AF_UNIX; - std::copy(host.begin(), host.end(), addr.sun_path); + + auto unescaped_host = unescape_abstract_namespace_unix_domain(host); + std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path); hints.ai_addr = reinterpret_cast(&addr); hints.ai_addrlen = static_cast( sizeof(addr) - sizeof(addr.sun_path) + addrlen); +#ifndef SOCK_CLOEXEC fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif + if (socket_options) { socket_options(sock); } - if (!bind_or_connect(sock, hints)) { + bool dummy; + if (!bind_or_connect(sock, hints, dummy)) { close_socket(sock); sock = INVALID_SOCKET; } @@ -3123,6 +3525,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, #endif return INVALID_SOCKET; } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); for (auto rp = result; rp; rp = rp->ai_next) { // Create a socket @@ -3148,11 +3551,18 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); } #else + +#ifdef SOCK_CLOEXEC + auto sock = + socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol); +#else auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + #endif if (sock == INVALID_SOCKET) { continue; } -#ifndef _WIN32 +#if !defined _WIN32 && !defined SOCK_CLOEXEC if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { close_socket(sock); continue; @@ -3160,39 +3570,38 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, #endif if (tcp_nodelay) { - auto yes = 1; + auto opt = 1; #ifdef _WIN32 setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); #else setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); +#endif + } + + if (rp->ai_family == AF_INET6) { + auto opt = ipv6_v6only ? 1 : 0; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&opt), sizeof(opt)); #endif } if (socket_options) { socket_options(sock); } - if (rp->ai_family == AF_INET6) { - auto no = 0; -#ifdef _WIN32 - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, - reinterpret_cast(&no), sizeof(no)); -#else - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, - reinterpret_cast(&no), sizeof(no)); -#endif - } - // bind or connect - if (bind_or_connect(sock, *rp)) { - freeaddrinfo(result); - return sock; - } + auto quit = false; + if (bind_or_connect(sock, *rp, quit)) { return sock; } close_socket(sock); + + if (quit) { break; } } - freeaddrinfo(result); return INVALID_SOCKET; } @@ -3225,6 +3634,7 @@ inline bool bind_ip_address(socket_t sock, const std::string &host) { hints.ai_protocol = 0; if (getaddrinfo(host.c_str(), "0", &hints, &result)) { return false; } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); auto ret = false; for (auto rp = result; rp; rp = rp->ai_next) { @@ -3235,7 +3645,6 @@ inline bool bind_ip_address(socket_t sock, const std::string &host) { } } - freeaddrinfo(result); return ret; } @@ -3247,6 +3656,8 @@ inline bool bind_ip_address(socket_t sock, const std::string &host) { inline std::string if2ip(int address_family, const std::string &ifn) { struct ifaddrs *ifap; getifaddrs(&ifap); + auto se = detail::scope_exit([&] { freeifaddrs(ifap); }); + std::string addr_candidate; for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { if (ifa->ifa_addr && ifn == ifa->ifa_name && @@ -3256,7 +3667,6 @@ inline std::string if2ip(int address_family, const std::string &ifn) { auto sa = reinterpret_cast(ifa->ifa_addr); char buf[INET_ADDRSTRLEN]; if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { - freeifaddrs(ifap); return std::string(buf, INET_ADDRSTRLEN); } } else if (ifa->ifa_addr->sa_family == AF_INET6) { @@ -3269,7 +3679,6 @@ inline std::string if2ip(int address_family, const std::string &ifn) { if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { addr_candidate = std::string(buf, INET6_ADDRSTRLEN); } else { - freeifaddrs(ifap); return std::string(buf, INET6_ADDRSTRLEN); } } @@ -3277,20 +3686,21 @@ inline std::string if2ip(int address_family, const std::string &ifn) { } } } - freeifaddrs(ifap); return addr_candidate; } #endif inline socket_t create_client_socket( const std::string &host, const std::string &ip, int port, - int address_family, bool tcp_nodelay, SocketOptions socket_options, - time_t connection_timeout_sec, time_t connection_timeout_usec, - time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + int address_family, bool tcp_nodelay, bool ipv6_v6only, + SocketOptions socket_options, time_t connection_timeout_sec, + time_t connection_timeout_usec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, const std::string &intf, Error &error) { auto sock = create_socket( - host, ip, port, address_family, 0, tcp_nodelay, std::move(socket_options), - [&](socket_t sock2, struct addrinfo &ai) -> bool { + host, ip, port, address_family, 0, tcp_nodelay, ipv6_v6only, + std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai, bool &quit) -> bool { if (!intf.empty()) { #ifdef USE_IF2IP auto ip_from_if = if2ip(address_family, intf); @@ -3314,7 +3724,10 @@ inline socket_t create_client_socket( } error = wait_until_socket_is_ready(sock2, connection_timeout_sec, connection_timeout_usec); - if (error != Error::Success) { return false; } + if (error != Error::Success) { + if (error == Error::ConnectionTimeout) { quit = true; } + return false; + } } set_nonblocking(sock2, false); @@ -3439,7 +3852,7 @@ inline unsigned int str2tag(const std::string &s) { namespace udl { -inline constexpr unsigned int operator"" _t(const char *s, size_t l) { +inline constexpr unsigned int operator""_t(const char *s, size_t l) { return str2tag_core(s, l, 0); } @@ -3524,8 +3937,9 @@ inline bool can_compress_content_type(const std::string &content_type) { case "application/protobuf"_t: case "application/xhtml+xml"_t: return true; - default: - return !content_type.rfind("text/", 0) && tag != "text/event-stream"_t; + case "text/event-stream"_t: return false; + + default: return !content_type.rfind("text/", 0); } } @@ -3762,8 +4176,8 @@ inline bool has_header(const Headers &headers, const std::string &key) { } inline const char *get_header_value(const Headers &headers, - const std::string &key, size_t id, - const char *def) { + const std::string &key, const char *def, + size_t id) { auto rng = headers.equal_range(key); auto it = rng.first; std::advance(it, static_cast(id)); @@ -3771,14 +4185,6 @@ inline const char *get_header_value(const Headers &headers, return def; } -inline bool compare_case_ignore(const std::string &a, const std::string &b) { - if (a.size() != b.size()) { return false; } - for (size_t i = 0; i < b.size(); i++) { - if (::tolower(a[i]) != ::tolower(b[i])) { return false; } - } - return true; -} - template inline bool parse_header(const char *beg, const char *end, T fn) { // Skip trailing spaces and tabs. @@ -3801,15 +4207,27 @@ inline bool parse_header(const char *beg, const char *end, T fn) { p++; } - if (p < end) { + if (p <= end) { auto key_len = key_end - beg; if (!key_len) { return false; } auto key = std::string(beg, key_end); - auto val = compare_case_ignore(key, "Location") + auto val = case_ignore::equal(key, "Location") ? std::string(p, end) : decode_url(std::string(p, end), false); - fn(std::move(key), std::move(val)); + + // NOTE: From RFC 9110: + // Field values containing CR, LF, or NUL characters are + // invalid and dangerous, due to the varying ways that + // implementations might parse and interpret those + // characters; a recipient of CR, LF, or NUL within a field + // value MUST either reject the message or replace each of + // those characters with SP before further processing or + // forwarding of that message. + static const std::string CR_LF_NUL("\r\n\0", 3); + if (val.find_first_of(CR_LF_NUL) != std::string::npos) { return false; } + + fn(key, val); return true; } @@ -3829,27 +4247,27 @@ inline bool read_headers(Stream &strm, Headers &headers) { if (line_reader.end_with_crlf()) { // Blank line indicates end of headers. if (line_reader.size() == 2) { break; } -#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR } else { +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR // Blank line indicates end of headers. if (line_reader.size() == 1) { break; } line_terminator_len = 1; - } #else - } else { continue; // Skip invalid line. - } #endif + } if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } // Exclude line terminator auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; - parse_header(line_reader.ptr(), end, - [&](std::string &&key, std::string &&val) { - headers.emplace(std::move(key), std::move(val)); - }); + if (!parse_header(line_reader.ptr(), end, + [&](const std::string &key, std::string &val) { + headers.emplace(key, val); + })) { + return false; + } } return true; @@ -3937,8 +4355,19 @@ inline bool read_content_chunked(Stream &strm, T &x, assert(chunk_len == 0); - // Trailer - if (!line_reader.getline()) { return false; } + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentiones "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // does't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-htpplib now allows + // chuncked transfer coding data without the final CRLF. + if (!line_reader.getline()) { return true; } while (strcmp(line_reader.ptr(), "\r\n") != 0) { if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } @@ -3948,8 +4377,8 @@ inline bool read_content_chunked(Stream &strm, T &x, auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; parse_header(line_reader.ptr(), end, - [&](std::string &&key, std::string &&val) { - x.headers.emplace(std::move(key), std::move(val)); + [&](const std::string &key, const std::string &val) { + x.headers.emplace(key, val); }); if (!line_reader.getline()) { return false; } @@ -3959,8 +4388,8 @@ inline bool read_content_chunked(Stream &strm, T &x, } inline bool is_chunked_transfer_encoding(const Headers &headers) { - return compare_case_ignore( - get_header_value(headers, "Transfer-Encoding", 0, ""), "chunked"); + return case_ignore::equal( + get_header_value(headers, "Transfer-Encoding", "", 0), "chunked"); } template @@ -4026,8 +4455,14 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, } else if (!has_header(x.headers, "Content-Length")) { ret = read_content_without_length(strm, out); } else { - auto len = get_header_value_u64(x.headers, "Content-Length", 0, 0); - if (len > payload_max_length) { + auto is_invalid_value = false; + auto len = get_header_value_u64(x.headers, "Content-Length", + std::numeric_limits::max(), + 0, is_invalid_value); + + if (is_invalid_value) { + ret = false; + } else if (len > payload_max_length) { exceed_payload_max_length = true; skip_content_with_length(strm, len); ret = false; @@ -4042,13 +4477,36 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, } return ret; }); -} // namespace detail +} + +inline ssize_t write_request_line(Stream &strm, const std::string &method, + const std::string &path) { + std::string s = method; + s += " "; + s += path; + s += " HTTP/1.1\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_response_line(Stream &strm, int status) { + std::string s = "HTTP/1.1 "; + s += std::to_string(status); + s += " "; + s += httplib::status_message(status); + s += "\r\n"; + return strm.write(s.data(), s.size()); +} inline ssize_t write_headers(Stream &strm, const Headers &headers) { ssize_t write_len = 0; for (const auto &x : headers) { - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + std::string s; + s = x.first; + s += ": "; + s += x.second; + s += "\r\n"; + + auto len = strm.write(s.data(), s.size()); if (len < 0) { return len; } write_len += len; } @@ -4302,22 +4760,22 @@ inline std::string params_to_query_str(const Params ¶ms) { return query; } -inline void parse_query_text(const std::string &s, Params ¶ms) { +inline void parse_query_text(const char *data, std::size_t size, + Params ¶ms) { std::set cache; - split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { + split(data, data + size, '&', [&](const char *b, const char *e) { std::string kv(b, e); if (cache.find(kv) != cache.end()) { return; } - cache.insert(kv); + cache.insert(std::move(kv)); std::string key; std::string val; - split(b, e, '=', [&](const char *b2, const char *e2) { - if (key.empty()) { - key.assign(b2, e2); - } else { - val.assign(b2, e2); - } - }); + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, + std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); if (!key.empty()) { params.emplace(decode_url(key, true), decode_url(val, true)); @@ -4325,6 +4783,10 @@ inline void parse_query_text(const std::string &s, Params ¶ms) { }); } +inline void parse_query_text(const std::string &s, Params ¶ms) { + parse_query_text(s.data(), s.size(), params); +} + inline bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { auto boundary_keyword = "boundary="; @@ -4365,35 +4827,44 @@ inline bool parse_range_header(const std::string &s, Ranges &ranges) { #else inline bool parse_range_header(const std::string &s, Ranges &ranges) try { #endif - static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); - std::smatch m; - if (std::regex_match(s, m, re_first_range)) { - auto pos = static_cast(m.position(1)); - auto len = static_cast(m.length(1)); + auto is_valid = [](const std::string &str) { + return std::all_of(str.cbegin(), str.cend(), + [](unsigned char c) { return std::isdigit(c); }); + }; + + if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) { + const auto pos = static_cast(6); + const auto len = static_cast(s.size() - 6); auto all_valid_ranges = true; split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { if (!all_valid_ranges) { return; } - static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); - std::cmatch cm; - if (std::regex_match(b, e, cm, re_another_range)) { - ssize_t first = -1; - if (!cm.str(1).empty()) { - first = static_cast(std::stoll(cm.str(1))); - } - ssize_t last = -1; - if (!cm.str(2).empty()) { - last = static_cast(std::stoll(cm.str(2))); - } - - if (first != -1 && last != -1 && first > last) { - all_valid_ranges = false; - return; - } - ranges.emplace_back(std::make_pair(first, last)); + const auto it = std::find(b, e, '-'); + if (it == e) { + all_valid_ranges = false; + return; } + + const auto lhs = std::string(b, it); + const auto rhs = std::string(it + 1, e); + if (!is_valid(lhs) || !is_valid(rhs)) { + all_valid_ranges = false; + return; + } + + const auto first = + static_cast(lhs.empty() ? -1 : std::stoll(lhs)); + const auto last = + static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + if ((first == -1 && last == -1) || + (first != -1 && last != -1 && first > last)) { + all_valid_ranges = false; + return; + } + + ranges.emplace_back(first, last); }); - return all_valid_ranges; + return all_valid_ranges && !ranges.empty(); } return false; #ifdef CPPHTTPLIB_NO_EXCEPTIONS @@ -4452,7 +4923,7 @@ public: const auto header = buf_head(pos); if (!parse_header(header.data(), header.data() + header.size(), - [&](std::string &&, std::string &&) {})) { + [&](const std::string &, const std::string &) {})) { is_valid_ = false; return false; } @@ -4562,7 +5033,9 @@ private: const std::string &b) const { if (a.size() < b.size()) { return false; } for (size_t i = 0; i < b.size(); i++) { - if (::tolower(a[i]) != ::tolower(b[i])) { return false; } + if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + return false; + } } return true; } @@ -4645,16 +5118,6 @@ private: size_t buf_epos_ = 0; }; -inline std::string to_lower(const char *beg, const char *end) { - std::string out; - auto it = beg; - while (it != end) { - out += static_cast(::tolower(*it)); - it++; - } - return out; -} - inline std::string random_string(size_t length) { static const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; @@ -4768,7 +5231,18 @@ inline bool range_error(Request &req, Response &res) { last_pos = contant_len - 1; } - if (last_pos == -1) { last_pos = contant_len - 1; } + // NOTE: RFC-9110 '14.1.2. Byte Ranges': + // A client can limit the number of bytes requested without knowing the + // size of the selected representation. If the last-pos value is absent, + // or if the value is greater than or equal to the current length of the + // representation data, the byte range is interpreted as the remainder of + // the representation (i.e., the server replaces the value of last-pos + // with a value that is one less than the current length of the selected + // representation). + // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 + if (last_pos == -1 || last_pos >= contant_len) { + last_pos = contant_len - 1; + } // Range must be within content length if (!(0 <= first_pos && first_pos <= last_pos && @@ -4795,12 +5269,11 @@ inline bool range_error(Request &req, Response &res) { inline std::pair get_range_offset_and_length(Range r, size_t content_length) { - (void)(content_length); // patch to get rid of "unused parameter" on release build assert(r.first != -1 && r.second != -1); assert(0 <= r.first && r.first < static_cast(content_length)); assert(r.first <= r.second && r.second < static_cast(content_length)); - + (void)(content_length); return std::make_pair(r.first, static_cast(r.second - r.first) + 1); } @@ -5230,6 +5703,7 @@ inline void hosted_at(const std::string &hostname, #endif return; } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); for (auto rp = result; rp; rp = rp->ai_next) { const auto &addr = @@ -5241,8 +5715,6 @@ inline void hosted_at(const std::string &hostname, addrs.push_back(ip); } } - - freeaddrinfo(result); } inline std::string append_query_params(const std::string &path, @@ -5291,8 +5763,8 @@ inline bool Request::has_header(const std::string &key) const { } inline std::string Request::get_header_value(const std::string &key, - size_t id) const { - return detail::get_header_value(headers, key, id, ""); + const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); } inline size_t Request::get_header_value_count(const std::string &key) const { @@ -5302,7 +5774,8 @@ inline size_t Request::get_header_value_count(const std::string &key) const { inline void Request::set_header(const std::string &key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { headers.emplace(key, val); } } @@ -5356,8 +5829,9 @@ inline bool Response::has_header(const std::string &key) const { } inline std::string Response::get_header_value(const std::string &key, + const char *def, size_t id) const { - return detail::get_header_value(headers, key, id, ""); + return detail::get_header_value(headers, key, def, id); } inline size_t Response::get_header_value_count(const std::string &key) const { @@ -5367,13 +5841,14 @@ inline size_t Response::get_header_value_count(const std::string &key) const { inline void Response::set_header(const std::string &key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { headers.emplace(key, val); } } inline void Response::set_redirect(const std::string &url, int stat) { - if (!detail::has_crlf(url)) { + if (detail::fields::is_field_value(url)) { set_header("Location", url); if (300 <= stat && stat < 400) { this->status = stat; @@ -5436,14 +5911,25 @@ inline void Response::set_chunked_content_provider( is_chunked_content_provider_ = true; } +inline void Response::set_file_content(const std::string &path, + const std::string &content_type) { + file_content_path_ = path; + file_content_content_type_ = content_type; +} + +inline void Response::set_file_content(const std::string &path) { + file_content_path_ = path; +} + // Result implementation inline bool Result::has_request_header(const std::string &key) const { return request_headers_.find(key) != request_headers_.end(); } inline std::string Result::get_request_header_value(const std::string &key, + const char *def, size_t id) const { - return detail::get_header_value(request_headers_, key, id, ""); + return detail::get_header_value(request_headers_, key, def, id); } inline size_t @@ -5584,6 +6070,8 @@ inline socket_t BufferStream::socket() const { return 0; } inline const std::string &BufferStream::get_buffer() const { return buffer; } inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + static constexpr char marker[] = "/:"; + // One past the last ending position of a path param substring std::size_t last_param_end = 0; @@ -5596,13 +6084,14 @@ inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { #endif while (true) { - const auto marker_pos = pattern.find(marker, last_param_end); + const auto marker_pos = pattern.find( + marker, last_param_end == 0 ? last_param_end : last_param_end - 1); if (marker_pos == std::string::npos) { break; } static_fragments_.push_back( - pattern.substr(last_param_end, marker_pos - last_param_end)); + pattern.substr(last_param_end, marker_pos - last_param_end + 1)); - const auto param_name_start = marker_pos + 1; + const auto param_name_start = marker_pos + 2; auto sep_pos = pattern.find(separator, param_name_start); if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } @@ -5664,7 +6153,7 @@ inline bool PathParamsMatcher::match(Request &request) const { request.path_params.emplace( param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); - // Mark everythin up to '/' as matched + // Mark everything up to '/' as matched starting_pos = sep_pos + 1; } // Returns false if the path is longer than the pattern @@ -5763,7 +6252,8 @@ inline bool Server::set_base_dir(const std::string &dir, inline bool Server::set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers) { - if (detail::is_dir(dir)) { + detail::FileStat stat(dir); + if (stat.is_dir()) { std::string mnt = !mount_point.empty() ? mount_point : "/"; if (!mnt.empty() && mnt[0] == '/') { base_dirs_.push_back({mnt, dir, std::move(headers)}); @@ -5800,12 +6290,14 @@ inline Server &Server::set_file_request_handler(Handler handler) { return *this; } -inline Server &Server::set_error_handler(HandlerWithResponse handler) { +inline Server &Server::set_error_handler_core(HandlerWithResponse handler, + std::true_type) { error_handler_ = std::move(handler); return *this; } -inline Server &Server::set_error_handler(Handler handler) { +inline Server &Server::set_error_handler_core(Handler handler, + std::false_type) { error_handler_ = [handler](const Request &req, Response &res) { handler(req, res); return HandlerResponse::Handled; @@ -5849,6 +6341,11 @@ inline Server &Server::set_tcp_nodelay(bool on) { return *this; } +inline Server &Server::set_ipv6_v6only(bool on) { + ipv6_v6only_ = on; + return *this; +} + inline Server &Server::set_socket_options(SocketOptions socket_options) { socket_options_ = std::move(socket_options); return *this; @@ -5900,27 +6397,27 @@ inline Server &Server::set_payload_max_length(size_t length) { inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { - return bind_internal(host, port, socket_flags) >= 0; + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { is_decommisioned = true; } + return ret >= 0; } inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { - return bind_internal(host, 0, socket_flags); + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { is_decommisioned = true; } + return ret; } -inline bool Server::listen_after_bind() { - auto se = detail::scope_exit([&]() { done_ = true; }); - return listen_internal(); -} +inline bool Server::listen_after_bind() { return listen_internal(); } inline bool Server::listen(const std::string &host, int port, int socket_flags) { - auto se = detail::scope_exit([&]() { done_ = true; }); return bind_to_port(host, port, socket_flags) && listen_internal(); } inline bool Server::is_running() const { return is_running_; } inline void Server::wait_until_ready() const { - while (!is_running() && !done_) { + while (!is_running_ && !is_decommisioned) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } } @@ -5932,8 +6429,11 @@ inline void Server::stop() { detail::shutdown_socket(sock); detail::close_socket(sock); } + is_decommisioned = false; } +inline void Server::decommission() { is_decommisioned = true; } + inline bool Server::parse_request_line(const char *s, Request &req) const { auto len = strlen(s); if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; } @@ -5972,26 +6472,13 @@ inline bool Server::parse_request_line(const char *s, Request &req) const { } } - size_t count = 0; - - detail::split(req.target.data(), req.target.data() + req.target.size(), '?', - 2, [&](const char *b, const char *e) { - switch (count) { - case 0: - req.path = detail::decode_url(std::string(b, e), false); - break; - case 1: { - if (e - b > 0) { - detail::parse_query_text(std::string(b, e), req.params); - } - break; - } - default: break; - } - count++; - }); - - if (count > 2) { return false; } + detail::divide(req.target, '?', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + req.path = detail::decode_url( + std::string(lhs_data, lhs_size), false); + detail::parse_query_text(rhs_data, rhs_size, req.params); + }); } return true; @@ -6030,23 +6517,24 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, if (close_connection || req.get_header_value("Connection") == "close") { res.set_header("Connection", "close"); } else { - std::stringstream ss; - ss << "timeout=" << keep_alive_timeout_sec_ - << ", max=" << keep_alive_max_count_; - res.set_header("Keep-Alive", ss.str()); + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); } - if (!res.has_header("Content-Type") && - (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { + if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && + !res.has_header("Content-Type")) { res.set_header("Content-Type", "text/plain"); } - if (!res.has_header("Content-Length") && res.body.empty() && - !res.content_length_ && !res.content_provider_) { + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && + !res.has_header("Content-Length")) { res.set_header("Content-Length", "0"); } - if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { res.set_header("Accept-Ranges", "bytes"); } @@ -6055,12 +6543,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, // Response line and headers { detail::BufferStream bstrm; - - if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, - status_message(res.status))) { - return false; - } - + if (!detail::write_response_line(bstrm, res.status)) { return false; } if (!header_writer_(bstrm, res.headers)) { return false; } // Flush buffer @@ -6254,7 +6737,14 @@ inline bool Server::handle_file_request(const Request &req, Response &res, auto path = entry.base_dir + sub_path; if (path.back() == '/') { path += "index.html"; } - if (detail::is_file(path)) { + detail::FileStat stat(path); + + if (stat.is_dir()) { + res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301); + return true; + } + + if (stat.is_file()) { for (const auto &kv : entry.headers) { res.set_header(kv.first, kv.second); } @@ -6289,8 +6779,8 @@ Server::create_server_socket(const std::string &host, int port, SocketOptions socket_options) const { return detail::create_socket( host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, - std::move(socket_options), - [](socket_t sock, struct addrinfo &ai) -> bool { + ipv6_v6only_, std::move(socket_options), + [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool { if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { return false; } @@ -6301,6 +6791,8 @@ Server::create_server_socket(const std::string &host, int port, inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { + if (is_decommisioned) { return -1; } + if (!is_valid()) { return -1; } svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); @@ -6326,6 +6818,8 @@ inline int Server::bind_internal(const std::string &host, int port, } inline bool Server::listen_internal() { + if (is_decommisioned) { return false; } + auto ret = true; is_running_ = true; auto se = detail::scope_exit([&]() { is_running_ = false; }); @@ -6346,13 +6840,22 @@ inline bool Server::listen_internal() { #ifndef _WIN32 } #endif + +#if defined _WIN32 + // sockets conneced via WASAccept inherit flags NO_HANDLE_INHERIT, + // OVERLAPPED + socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); +#elif defined SOCK_CLOEXEC + socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC); +#else socket_t sock = accept(svr_sock_, nullptr, nullptr); +#endif if (sock == INVALID_SOCKET) { if (errno == EMFILE) { // The per-process limit of open file descriptors has been reached. // Try to accept new connections after a short sleep. - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::this_thread::sleep_for(std::chrono::microseconds{1}); continue; } else if (errno == EINTR || errno == EAGAIN) { continue; @@ -6406,6 +6909,7 @@ inline bool Server::listen_internal() { task_queue->shutdown(); } + is_decommisioned = !ret; return ret; } @@ -6503,7 +7007,7 @@ inline bool Server::dispatch_request(Request &req, Response &res, inline void Server::apply_ranges(const Request &req, Response &res, std::string &content_type, std::string &boundary) const { - if (req.ranges.size() > 1) { + if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) { auto it = res.headers.find("Content-Type"); if (it != res.headers.end()) { content_type = it->second; @@ -6521,7 +7025,7 @@ inline void Server::apply_ranges(const Request &req, Response &res, if (res.body.empty()) { if (res.content_length_ > 0) { size_t length = 0; - if (req.ranges.empty()) { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { length = res.content_length_; } else if (req.ranges.size() == 1) { auto offset_and_length = detail::get_range_offset_and_length( @@ -6550,7 +7054,7 @@ inline void Server::apply_ranges(const Request &req, Response &res, } } } else { - if (req.ranges.empty()) { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { ; } else if (req.ranges.size() == 1) { auto offset_and_length = @@ -6621,7 +7125,9 @@ inline bool Server::dispatch_request_for_content_reader( } inline bool -Server::process_request(Stream &strm, bool close_connection, +Server::process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, bool &connection_closed, const std::function &setup_request) { std::array buf{}; @@ -6675,11 +7181,13 @@ Server::process_request(Stream &strm, bool close_connection, connection_closed = true; } - strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.remote_addr = remote_addr; + req.remote_port = remote_port; req.set_header("REMOTE_ADDR", req.remote_addr); req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); - strm.get_local_ip_and_port(req.local_addr, req.local_port); + req.local_addr = local_addr; + req.local_port = local_port; req.set_header("LOCAL_ADDR", req.local_addr); req.set_header("LOCAL_PORT", std::to_string(req.local_port)); @@ -6701,13 +7209,20 @@ Server::process_request(Stream &strm, bool close_connection, switch (status) { case StatusCode::Continue_100: case StatusCode::ExpectationFailed_417: - strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, - status_message(status)); + detail::write_response_line(strm, status); + strm.write("\r\n"); break; - default: return write_response(strm, close_connection, req, res); + default: + connection_closed = true; + return write_response(strm, true, req, res); } } + // Setup `is_connection_closed` method + req.is_connection_closed = [&]() { + return !detail::is_socket_alive(strm.socket()); + }; + // Routing auto routed = false; #ifdef CPPHTTPLIB_NO_EXCEPTIONS @@ -6750,6 +7265,32 @@ Server::process_request(Stream &strm, bool close_connection, : StatusCode::PartialContent_206; } + // Serve file content by using a content provider + if (!res.file_content_path_.empty()) { + const auto &path = res.file_content_path_; + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::NotFound_404; + return write_response(strm, close_connection, req, res); + } + + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type( + path, file_extension_and_mimetype_map_, default_file_mimetype_); + } + + res.set_content_provider( + mm->size(), content_type, + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + } + if (detail::range_error(req, res)) { res.body.clear(); res.content_length_ = 0; @@ -6769,12 +7310,21 @@ Server::process_request(Stream &strm, bool close_connection, inline bool Server::is_valid() const { return true; } inline bool Server::process_and_close_socket(socket_t sock) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + auto ret = detail::process_server_socket( svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, connection_closed, nullptr); }); @@ -6793,8 +7343,8 @@ inline ClientImpl::ClientImpl(const std::string &host, int port) inline ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) - : host_(host), port_(port), - host_and_port_(adjust_host_string(host) + ":" + std::to_string(port)), + : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), + host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} inline ClientImpl::~ClientImpl() { @@ -6825,6 +7375,7 @@ inline void ClientImpl::copy_settings(const ClientImpl &rhs) { url_encode_ = rhs.url_encode_; address_family_ = rhs.address_family_; tcp_nodelay_ = rhs.tcp_nodelay_; + ipv6_v6only_ = rhs.ipv6_v6only_; socket_options_ = rhs.socket_options_; compress_ = rhs.compress_; decompress_ = rhs.decompress_; @@ -6845,6 +7396,8 @@ inline void ClientImpl::copy_settings(const ClientImpl &rhs) { #endif #ifdef CPPHTTPLIB_OPENSSL_SUPPORT server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; + server_certificate_verifier_ = rhs.server_certificate_verifier_; #endif logger_ = rhs.logger_; } @@ -6853,9 +7406,9 @@ inline socket_t ClientImpl::create_client_socket(Error &error) const { if (!proxy_host_.empty() && proxy_port_ != -1) { return detail::create_client_socket( proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, - socket_options_, connection_timeout_sec_, connection_timeout_usec_, - read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, interface_, error); + ipv6_v6only_, socket_options_, connection_timeout_sec_, + connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, interface_, error); } // Check is custom IP specified for host_ @@ -6864,10 +7417,10 @@ inline socket_t ClientImpl::create_client_socket(Error &error) const { if (it != addr_map_.end()) { ip = it->second; } return detail::create_client_socket( - host_, ip, port_, address_family_, tcp_nodelay_, socket_options_, - connection_timeout_sec_, connection_timeout_usec_, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, interface_, - error); + host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); } inline bool ClientImpl::create_and_connect_socket(Socket &socket, @@ -6956,6 +7509,18 @@ inline bool ClientImpl::send(Request &req, Response &res, Error &error) { return ret; } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool ClientImpl::is_ssl_peer_could_be_closed(SSL *ssl) const { + detail::set_nonblocking(socket_.sock, true); + auto se = detail::scope_exit( + [&]() { detail::set_nonblocking(socket_.sock, false); }); + + char buf[1]; + return !SSL_peek(ssl, buf, 1) && + SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} +#endif + inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { { std::lock_guard guard(socket_mutex_); @@ -6967,6 +7532,13 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { auto is_alive = false; if (socket_.is_open()) { is_alive = detail::is_socket_alive(socket_.sock); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { is_alive = false; } + } +#endif + if (!is_alive) { // Attempt to avoid sigpipe by shutting down nongracefully if it seems // like the other side has already closed the connection Also, there @@ -7144,7 +7716,7 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { if (location.empty()) { return false; } const static std::regex re( - R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); std::smatch m; if (!std::regex_match(location, m, re)) { return false; } @@ -7243,12 +7815,26 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } -#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT - if (!req.has_header("User-Agent")) { - auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; - req.set_header("User-Agent", agent); - } + if (!req.content_receiver) { + if (!req.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; #endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif + req.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + }; if (req.body.empty()) { if (req.content_provider_) { @@ -7308,8 +7894,14 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, { detail::BufferStream bstrm; - const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path; - bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + const auto &path_with_query = + req.params.empty() ? req.path + : append_query_params(req.path, req.params); + + const auto &path = + url_encode_ ? detail::encode_url(path_with_query) : path_with_query; + + detail::write_request_line(bstrm, req.method, path); header_writer_(bstrm, req.headers); @@ -7417,11 +8009,12 @@ inline Result ClientImpl::send_with_content_provider( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type) { + const std::string &content_type, Progress progress) { Request req; req.method = method; req.headers = headers; req.path = path; + req.progress = progress; auto error = Error::Success; @@ -7448,9 +8041,7 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, if (is_ssl()) { auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; if (!is_proxy_enabled) { - char buf[1]; - if (SSL_peek(socket_.ssl, buf, 1) == 0 && - SSL_get_error(socket_.ssl, 0) == SSL_ERROR_ZERO_RETURN) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { error = Error::SSLPeerCouldBeClosed_; return false; } @@ -7468,7 +8059,9 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, // Body if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && req.method != "CONNECT") { - auto redirect = 300 < res.status && res.status < 400 && follow_location_; + auto redirect = 300 < res.status && res.status < 400 && + res.status != StatusCode::NotModified_304 && + follow_location_; if (req.response_handler && !redirect) { if (!req.response_handler(res)) { @@ -7489,9 +8082,7 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, : static_cast( [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { - if (res.body.size() + n > res.body.max_size()) { - return false; - } + assert(res.body.size() + n <= res.body.max_size()); res.body.append(buf, n); return true; }); @@ -7503,12 +8094,25 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, return ret; }; - int dummy_status; - if (!detail::read_content(strm, res, (std::numeric_limits::max)(), - dummy_status, std::move(progress), std::move(out), - decompress_)) { - if (error != Error::Canceled) { error = Error::Read; } - return false; + if (res.has_header("Content-Length")) { + if (!req.content_receiver) { + auto len = res.get_header_value_u64("Content-Length"); + if (len > res.body.max_size()) { + error = Error::Read; + return false; + } + res.body.reserve(static_cast(len)); + } + } + + if (res.status != StatusCode::NotModified_304) { + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, std::move(progress), + std::move(out), decompress_)) { + if (error != Error::Canceled) { error = Error::Read; } + return false; + } } } @@ -7717,14 +8321,22 @@ inline Result ClientImpl::Post(const std::string &path, inline Result ClientImpl::Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type) { - return Post(path, Headers(), body, content_length, content_type); + return Post(path, Headers(), body, content_length, content_type, nullptr); } inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { return send_with_content_provider("POST", path, headers, body, content_length, - nullptr, nullptr, content_type); + nullptr, nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); } inline Result ClientImpl::Post(const std::string &path, const std::string &body, @@ -7732,12 +8344,27 @@ inline Result ClientImpl::Post(const std::string &path, const std::string &body, return Post(path, Headers(), body, content_type); } +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return Post(path, Headers(), body, content_type, progress); +} + inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return send_with_content_provider("POST", path, headers, body.data(), - body.size(), nullptr, nullptr, - content_type); + body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); } inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { @@ -7763,14 +8390,15 @@ inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &content_type) { return send_with_content_provider("POST", path, headers, nullptr, content_length, std::move(content_provider), - nullptr, content_type); + nullptr, content_type, nullptr); } inline Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type) { return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type); + std::move(content_provider), content_type, + nullptr); } inline Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7779,6 +8407,13 @@ inline Result ClientImpl::Post(const std::string &path, const Headers &headers, return Post(path, headers, query, "application/x-www-form-urlencoded"); } +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + inline Result ClientImpl::Post(const std::string &path, const MultipartFormDataItems &items) { return Post(path, Headers(), items); @@ -7816,7 +8451,7 @@ ClientImpl::Post(const std::string &path, const Headers &headers, return send_with_content_provider( "POST", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type); + content_type, nullptr); } inline Result ClientImpl::Put(const std::string &path) { @@ -7833,7 +8468,15 @@ inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { return send_with_content_provider("PUT", path, headers, body, content_length, - nullptr, nullptr, content_type); + nullptr, nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); } inline Result ClientImpl::Put(const std::string &path, const std::string &body, @@ -7841,12 +8484,27 @@ inline Result ClientImpl::Put(const std::string &path, const std::string &body, return Put(path, Headers(), body, content_type); } +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return Put(path, Headers(), body, content_type, progress); +} + inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return send_with_content_provider("PUT", path, headers, body.data(), - body.size(), nullptr, nullptr, - content_type); + body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); } inline Result ClientImpl::Put(const std::string &path, size_t content_length, @@ -7868,14 +8526,15 @@ inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &content_type) { return send_with_content_provider("PUT", path, headers, nullptr, content_length, std::move(content_provider), - nullptr, content_type); + nullptr, content_type, nullptr); } inline Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type) { return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type); + std::move(content_provider), content_type, + nullptr); } inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { @@ -7888,6 +8547,13 @@ inline Result ClientImpl::Put(const std::string &path, const Headers &headers, return Put(path, headers, query, "application/x-www-form-urlencoded"); } +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + inline Result ClientImpl::Put(const std::string &path, const MultipartFormDataItems &items) { return Put(path, Headers(), items); @@ -7925,7 +8591,7 @@ ClientImpl::Put(const std::string &path, const Headers &headers, return send_with_content_provider( "PUT", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type); + content_type, nullptr); } inline Result ClientImpl::Patch(const std::string &path) { return Patch(path, std::string(), std::string()); @@ -7937,12 +8603,26 @@ inline Result ClientImpl::Patch(const std::string &path, const char *body, return Patch(path, Headers(), body, content_length, content_type); } +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_length, content_type, progress); +} + inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { + return Patch(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { return send_with_content_provider("PATCH", path, headers, body, content_length, nullptr, nullptr, - content_type); + content_type, progress); } inline Result ClientImpl::Patch(const std::string &path, @@ -7951,12 +8631,26 @@ inline Result ClientImpl::Patch(const std::string &path, return Patch(path, Headers(), body, content_type); } +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_type, progress); +} + inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { + return Patch(path, headers, body, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { return send_with_content_provider("PATCH", path, headers, body.data(), - body.size(), nullptr, nullptr, - content_type); + body.size(), nullptr, nullptr, content_type, + progress); } inline Result ClientImpl::Patch(const std::string &path, size_t content_length, @@ -7978,14 +8672,15 @@ inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &content_type) { return send_with_content_provider("PATCH", path, headers, nullptr, content_length, std::move(content_provider), - nullptr, content_type); + nullptr, content_type, nullptr); } inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type) { return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type); + std::move(content_provider), content_type, + nullptr); } inline Result ClientImpl::Delete(const std::string &path) { @@ -8003,14 +8698,30 @@ inline Result ClientImpl::Delete(const std::string &path, const char *body, return Delete(path, Headers(), body, content_length, content_type); } +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body, content_length, content_type, progress); +} + inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { + return Delete(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { Request req; req.method = "DELETE"; req.headers = headers; req.path = path; + req.progress = progress; if (!content_type.empty()) { req.set_header("Content-Type", content_type); } req.body.assign(body, content_length); @@ -8024,6 +8735,14 @@ inline Result ClientImpl::Delete(const std::string &path, return Delete(path, Headers(), body.data(), body.size(), content_type); } +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body.data(), body.size(), content_type, + progress); +} + inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const std::string &body, @@ -8031,6 +8750,15 @@ inline Result ClientImpl::Delete(const std::string &path, return Delete(path, headers, body.data(), body.size(), content_type); } +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Delete(path, headers, body.data(), body.size(), content_type, + progress); +} + inline Result ClientImpl::Options(const std::string &path) { return Options(path, Headers()); } @@ -8138,6 +8866,8 @@ inline void ClientImpl::set_address_family(int family) { inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void ClientImpl::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + inline void ClientImpl::set_socket_options(SocketOptions socket_options) { socket_options_ = std::move(socket_options); } @@ -8187,13 +8917,11 @@ inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, std::size_t size) const { auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + auto se = detail::scope_exit([&] { BIO_free_all(mem); }); if (!mem) { return nullptr; } auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); - if (!inf) { - BIO_free_all(mem); - return nullptr; - } + if (!inf) { return nullptr; } auto cts = X509_STORE_new(); if (cts) { @@ -8207,13 +8935,21 @@ inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, } sk_X509_INFO_pop_free(inf, X509_INFO_free); - BIO_free_all(mem); return cts; } inline void ClientImpl::enable_server_certificate_verification(bool enabled) { server_certificate_verification_ = enabled; } + +inline void ClientImpl::enable_server_hostname_verification(bool enabled) { + server_hostname_verification_ = enabled; +} + +inline void ClientImpl::set_server_certificate_verifier( + std::function verifier) { + server_certificate_verifier_ = verifier; +} #endif inline void ClientImpl::set_logger(Logger logger) { @@ -8257,13 +8993,30 @@ inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, return ssl; } -inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, bool shutdown_gracefully) { // sometimes we may want to skip this to try to avoid SIGPIPE if we know // the remote has closed the network connection // Note that it is not always possible to avoid SIGPIPE, this is merely a // best-efforts. - if (shutdown_gracefully) { SSL_shutdown(ssl); } + if (shutdown_gracefully) { +#ifdef _WIN32 + (void)(sock); + SSL_shutdown(ssl); +#else + timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); + + auto ret = SSL_shutdown(ssl); + while (ret == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds{100}); + ret = SSL_shutdown(ssl); + } +#endif + } std::lock_guard guard(ctx_mutex); SSL_free(ssl); @@ -8366,7 +9119,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); } else if (is_readable()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret >= 0) { return ret; } err = SSL_get_error(ssl_, ret); @@ -8397,7 +9150,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { #endif if (is_writable()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_write(ssl_, ptr, static_cast(handle_size)); if (ret >= 0) { return ret; } err = SSL_get_error(ssl_, ret); @@ -8439,7 +9192,7 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); if (private_key_password != nullptr && (private_key_password[0] != '\0')) { SSL_CTX_set_default_passwd_cb_userdata( @@ -8449,7 +9202,8 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1) { + 1 || + SSL_CTX_check_private_key(ctx_) != 1) { SSL_CTX_free(ctx_); ctx_ = nullptr; } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { @@ -8471,7 +9225,7 @@ inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); if (SSL_CTX_use_certificate(ctx_, cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { @@ -8505,6 +9259,19 @@ inline bool SSLServer::is_valid() const { return ctx_; } inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } +inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + + std::lock_guard guard(ctx_mutex_); + + SSL_CTX_use_certificate(ctx_, cert); + SSL_CTX_use_PrivateKey(ctx_, private_key); + + if (client_ca_cert_store != nullptr) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + } +} + inline bool SSLServer::process_and_close_socket(socket_t sock) { auto ssl = detail::ssl_new( sock, ctx_, ctx_mutex_, @@ -8516,20 +9283,29 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) { auto ret = false; if (ssl) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + ret = detail::process_server_socket_ssl( svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this, ssl](Stream &strm, bool close_connection, - bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, + connection_closed, [&](Request &req) { req.ssl = ssl; }); }); // Shutdown gracefully if the result seemed successful, non-gracefully if // the connection appeared to be closed. const bool shutdown_gracefully = ret; - detail::ssl_delete(ctx_mutex_, ssl, shutdown_gracefully); + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); } detail::shutdown_socket(sock); @@ -8551,6 +9327,8 @@ inline SSLClient::SSLClient(const std::string &host, int port, : ClientImpl(host, port, client_cert_path, client_key_path) { ctx_ = SSL_CTX_new(TLS_client_method()); + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { host_components_.emplace_back(b, e); @@ -8758,36 +9536,47 @@ inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { } if (server_certificate_verification_) { - verify_result_ = SSL_get_verify_result(ssl2); + if (server_certificate_verifier_) { + if (!server_certificate_verifier_(ssl2)) { + error = Error::SSLServerVerification; + return false; + } + } else { + verify_result_ = SSL_get_verify_result(ssl2); - if (verify_result_ != X509_V_OK) { - error = Error::SSLServerVerification; - return false; + if (verify_result_ != X509_V_OK) { + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + auto se = detail::scope_exit([&] { X509_free(server_cert); }); + + if (server_cert == nullptr) { + error = Error::SSLServerVerification; + return false; + } + + if (server_hostname_verification_) { + if (!verify_host(server_cert)) { + error = Error::SSLServerHostnameVerification; + return false; + } + } } - - auto server_cert = SSL_get1_peer_certificate(ssl2); - - if (server_cert == nullptr) { - error = Error::SSLServerVerification; - return false; - } - - if (!verify_host(server_cert)) { - X509_free(server_cert); - error = Error::SSLServerVerification; - return false; - } - X509_free(server_cert); } return true; }, [&](SSL *ssl2) { +#if defined(OPENSSL_IS_BORINGSSL) + SSL_set_tlsext_host_name(ssl2, host_.c_str()); +#else // NOTE: Direct call instead of using the OpenSSL macro to suppress // -Wold-style-cast warning - // SSL_set_tlsext_host_name(ssl2, host_.c_str()); SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, static_cast(const_cast(host_.c_str()))); +#endif return true; }); @@ -8812,7 +9601,8 @@ inline void SSLClient::shutdown_ssl_impl(Socket &socket, return; } if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully); + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, + shutdown_gracefully); socket.ssl = nullptr; } assert(socket.ssl == nullptr); @@ -8861,8 +9651,8 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { auto type = GEN_DNS; - struct in6_addr addr6 {}; - struct in_addr addr {}; + struct in6_addr addr6{}; + struct in_addr addr{}; size_t addr_len = 0; #ifndef __MINGW32__ @@ -8965,7 +9755,7 @@ inline Client::Client(const std::string &scheme_host_port, const std::string &client_cert_path, const std::string &client_key_path) { const static std::regex re( - R"((?:([a-z]+):\/\/)?(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); std::smatch m; if (std::regex_match(scheme_host_port, m, re)) { @@ -9002,10 +9792,12 @@ inline Client::Client(const std::string &scheme_host_port, client_key_path); } } else { + // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress) + // if port param below changes. cli_ = detail::make_unique(scheme_host_port, 80, client_cert_path, client_key_path); } -} +} // namespace detail inline Client::Client(const std::string &host, int port) : cli_(detail::make_unique(host, port)) {} @@ -9111,15 +9903,30 @@ inline Result Client::Post(const std::string &path, const Headers &headers, const std::string &content_type) { return cli_->Post(path, headers, body, content_length, content_type); } +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_length, content_type, + progress); +} inline Result Client::Post(const std::string &path, const std::string &body, const std::string &content_type) { return cli_->Post(path, body, content_type); } +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, body, content_type, progress); +} inline Result Client::Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return cli_->Post(path, headers, body, content_type); } +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_type, progress); +} inline Result Client::Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type) { @@ -9150,6 +9957,10 @@ inline Result Client::Post(const std::string &path, const Headers &headers, const Params ¶ms) { return cli_->Post(path, headers, params); } +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + return cli_->Post(path, headers, params, progress); +} inline Result Client::Post(const std::string &path, const MultipartFormDataItems &items) { return cli_->Post(path, items); @@ -9180,15 +9991,29 @@ inline Result Client::Put(const std::string &path, const Headers &headers, const std::string &content_type) { return cli_->Put(path, headers, body, content_length, content_type); } +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_length, content_type, progress); +} inline Result Client::Put(const std::string &path, const std::string &body, const std::string &content_type) { return cli_->Put(path, body, content_type); } +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, body, content_type, progress); +} inline Result Client::Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return cli_->Put(path, headers, body, content_type); } +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_type, progress); +} inline Result Client::Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type) { @@ -9219,6 +10044,10 @@ inline Result Client::Put(const std::string &path, const Headers &headers, const Params ¶ms) { return cli_->Put(path, headers, params); } +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + return cli_->Put(path, headers, params, progress); +} inline Result Client::Put(const std::string &path, const MultipartFormDataItems &items) { return cli_->Put(path, items); @@ -9246,20 +10075,44 @@ inline Result Client::Patch(const std::string &path, const char *body, const std::string &content_type) { return cli_->Patch(path, body, content_length, content_type); } +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_length, content_type, progress); +} inline Result Client::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { return cli_->Patch(path, headers, body, content_length, content_type); } +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, headers, body, content_length, content_type, + progress); +} inline Result Client::Patch(const std::string &path, const std::string &body, const std::string &content_type) { return cli_->Patch(path, body, content_type); } +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_type, progress); +} inline Result Client::Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return cli_->Patch(path, headers, body, content_type); } +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, headers, body, content_type, progress); +} inline Result Client::Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type) { @@ -9294,20 +10147,44 @@ inline Result Client::Delete(const std::string &path, const char *body, const std::string &content_type) { return cli_->Delete(path, body, content_length, content_type); } +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_length, content_type, progress); +} inline Result Client::Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { return cli_->Delete(path, headers, body, content_length, content_type); } +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, headers, body, content_length, content_type, + progress); +} inline Result Client::Delete(const std::string &path, const std::string &body, const std::string &content_type) { return cli_->Delete(path, body, content_type); } +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_type, progress); +} inline Result Client::Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return cli_->Delete(path, headers, body, content_type); } +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, headers, body, content_type, progress); +} inline Result Client::Options(const std::string &path) { return cli_->Options(path); } @@ -9417,6 +10294,15 @@ inline void Client::set_proxy_digest_auth(const std::string &username, inline void Client::enable_server_certificate_verification(bool enabled) { cli_->enable_server_certificate_verification(enabled); } + +inline void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +inline void Client::set_server_certificate_verifier( + std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} #endif inline void Client::set_logger(Logger logger) { diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index 3640a7a6cf..26f3583bdf 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 127323e776..d1e8ee8291 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -19,6 +19,7 @@ #include "loading.html.hpp" #include +#include #include #include #include @@ -32,6 +33,8 @@ using json = nlohmann::ordered_json; +constexpr int HTTP_POLLING_SECONDS = 1; + enum stop_type { STOP_TYPE_NONE, STOP_TYPE_EOS, @@ -98,7 +101,7 @@ struct slot_params { int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit - std::vector lora; + std::vector lora; std::vector antiprompt; std::vector response_fields; @@ -198,15 +201,17 @@ struct server_task { bool metrics_reset_bucket = false; // used by SERVER_TASK_TYPE_SET_LORA - std::vector set_lora; + std::vector set_lora; server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( - const llama_model * model, const llama_context * ctx, const common_params & params_base, const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + slot_params params; // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) @@ -329,7 +334,7 @@ struct server_task { const auto & logit_bias = data.find("logit_bias"); if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(model); + const int n_vocab = llama_vocab_n_tokens(vocab); for (const auto & el : *logit_bias) { // TODO: we may want to throw errors here, in case "el" is incorrect if (el.is_array() && el.size() == 2) { @@ -348,7 +353,7 @@ struct server_task { params.sampling.logit_bias.push_back({tok, bias}); } } else if (el[0].is_string()) { - auto toks = common_tokenize(model, el[0].get(), false); + auto toks = common_tokenize(vocab, el[0].get(), false); for (auto tok : toks) { params.sampling.logit_bias.push_back({tok, bias}); } @@ -1131,7 +1136,7 @@ struct server_slot { common_speculative * spec = nullptr; - std::vector lora; + std::vector lora; // the index relative to completion multi-task request size_t index = 0; @@ -1600,6 +1605,30 @@ struct server_response { // should never reach here } + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{ + return !queue_results.empty(); + }); + if (!cr_res) { + return nullptr; + } + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + // single-task version of recv() server_task_result_ptr recv(int id_task) { std::unordered_set id_tasks = {id_task}; @@ -1633,6 +1662,8 @@ struct server_context { llama_model * model = nullptr; llama_context * ctx = nullptr; + const llama_vocab * vocab = nullptr; + llama_model * model_dft = nullptr; llama_context_params cparams_dft; @@ -1690,10 +1721,12 @@ struct server_context { return false; } + vocab = llama_model_get_vocab(model); + n_ctx = llama_n_ctx(ctx); - add_bos_token = llama_add_bos_token(model); - has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL; + add_bos_token = llama_vocab_get_add_bos(vocab); + has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; if (!params_base.speculative.model.empty()) { SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); @@ -1736,7 +1769,8 @@ struct server_context { bool validate_builtin_chat_template() const { llama_chat_message chat[] = {{"user", "test"}}; - int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); + const char * tmpl = llama_model_chat_template(model); + const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); return chat_res > 0; } @@ -1756,7 +1790,7 @@ struct server_context { if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft); + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { SRV_ERR("%s", "failed to create draft context\n"); return; @@ -1891,7 +1925,7 @@ struct server_context { } if (slot.params.ignore_eos && has_eos_token) { - slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY}); + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); } { @@ -2047,14 +2081,14 @@ struct server_context { slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); } - if (llama_token_is_eog(model, result.tok)) { + if (llama_vocab_is_eog(vocab, result.tok)) { slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; SLT_DBG(slot, "%s", "stopped by EOS\n"); } - const auto n_ctx_train = llama_n_ctx_train(model); + const auto n_ctx_train = llama_model_n_ctx_train(model); if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { slot.truncated = true; @@ -2074,7 +2108,7 @@ struct server_context { void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { size_t n_probs = slot.params.sampling.n_probs; - size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); + size_t n_vocab = llama_vocab_n_tokens(vocab); if (post_sampling) { const auto * cur_p = common_sampler_get_candidates(slot.smpl); const size_t max_probs = cur_p->size; @@ -2225,7 +2259,7 @@ struct server_context { res->n_tokens = slot.n_prompt_tokens; res->oaicompat = slot.params.oaicompat; - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embd_res(n_embd, 0.0f); @@ -2315,10 +2349,21 @@ struct server_context { void receive_multi_results( const std::unordered_set & id_tasks, const std::function&)> & result_handler, - const std::function & error_handler) { + const std::function & error_handler, + const std::function & is_connection_closed) { std::vector results(id_tasks.size()); - for (size_t i = 0; i < id_tasks.size(); i++) { - server_task_result_ptr result = queue_results.recv(id_tasks); + for (int i = 0; i < (int)id_tasks.size(); i++) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + i--; // retry + continue; + } if (result->is_error()) { error_handler(result->to_json()); @@ -2342,10 +2387,20 @@ struct server_context { void receive_cmpl_results_stream( const std::unordered_set & id_tasks, const std::function & result_handler, - const std::function & error_handler) { + const std::function & error_handler, + const std::function & is_connection_closed) { size_t n_finished = 0; while (true) { - server_task_result_ptr result = queue_results.recv(id_tasks); + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + continue; // retry + } if (result->is_error()) { error_handler(result->to_json()); @@ -2927,7 +2982,7 @@ struct server_context { // make sure we're in the right embedding mode llama_set_embeddings(ctx, slot_batched->is_non_causal()); // apply lora, only need to do it once per batch - common_lora_adapters_apply(ctx, slot_batched->lora); + common_set_adapter_lora(ctx, slot_batched->lora); } // process the created batch of tokens @@ -3129,12 +3184,12 @@ struct server_context { json model_meta() const { return json { - {"vocab_type", llama_vocab_type (model)}, - {"n_vocab", llama_n_vocab (model)}, - {"n_ctx_train", llama_n_ctx_train (model)}, - {"n_embd", llama_n_embd (model)}, - {"n_params", llama_model_n_params(model)}, - {"size", llama_model_size (model)}, + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, }; } }; @@ -3626,6 +3681,7 @@ int main(int argc, char ** argv) { const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( server_task_type type, json & data, + std::function is_connection_closed, httplib::Response & res, oaicompat_type oaicompat) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); @@ -3639,7 +3695,7 @@ int main(int argc, char ** argv) { std::vector tasks; try { - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true); tasks.reserve(tokenized_prompts.size()); for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(type); @@ -3649,7 +3705,6 @@ int main(int argc, char ** argv) { task.prompt_tokens = std::move(tokenized_prompts[i]); task.params = server_task::params_from_json_cmpl( - ctx_server.model, ctx_server.ctx, ctx_server.params_base, data); @@ -3688,7 +3743,7 @@ int main(int argc, char ** argv) { } }, [&](const json & error_data) { res_error(res, error_data); - }); + }, is_connection_closed); ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { @@ -3698,6 +3753,7 @@ int main(int argc, char ** argv) { if (res_json.is_array()) { for (const auto & res : res_json) { if (!server_sent_event(sink, "data", res)) { + // sending failed (HTTP connection closed), cancel the generation return false; } } @@ -3707,6 +3763,9 @@ int main(int argc, char ** argv) { } }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); + }, [&sink]() { + // note: do not use req.is_connection_closed here because req is already destroyed + return !sink.is_writable(); }); if (oaicompat != OAICOMPAT_TYPE_NONE) { static const std::string ev_done = "data: [DONE]\n\n"; @@ -3729,6 +3788,7 @@ int main(int argc, char ** argv) { return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); }; @@ -3738,6 +3798,7 @@ int main(int argc, char ** argv) { return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + req.is_connection_closed, res, OAICOMPAT_TYPE_COMPLETION); }; @@ -3745,13 +3806,13 @@ int main(int argc, char ** argv) { const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { // check model compatibility std::string err; - if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { err += "prefix token is missing. "; } - if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { err += "suffix token is missing. "; } - if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { err += "middle token is missing. "; } if (!err.empty()) { @@ -3797,10 +3858,10 @@ int main(int argc, char ** argv) { data["input_extra"] = input_extra; // default to empty array if it's not exist std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, false, true); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true); SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); data["prompt"] = format_infill( - ctx_server.ctx, + ctx_server.vocab, data.at("input_prefix"), data.at("input_suffix"), data.at("input_extra"), @@ -3814,6 +3875,7 @@ int main(int argc, char ** argv) { return handle_completions_impl( SERVER_TASK_TYPE_INFILL, data, + req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); // infill is not OAI compatible }; @@ -3828,6 +3890,7 @@ int main(int argc, char ** argv) { return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + req.is_connection_closed, res, OAICOMPAT_TYPE_CHAT); }; @@ -3857,7 +3920,7 @@ int main(int argc, char ** argv) { const bool add_special = json_value(body, "add_special", false); const bool with_pieces = json_value(body, "with_pieces", false); - llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true); + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true); if (with_pieces) { for (const auto& token : tokens) { @@ -3933,7 +3996,7 @@ int main(int argc, char ** argv) { } } - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input if (tokens.empty()) { @@ -3974,7 +4037,7 @@ int main(int argc, char ** argv) { }, [&](const json & error_data) { res_error(res, error_data); error = true; - }); + }, req.is_connection_closed); ctx_server.queue_results.remove_waiting_task_ids(task_ids); } @@ -4033,20 +4096,20 @@ int main(int argc, char ** argv) { return; } - llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0]; + llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0]; // create and queue the task json responses = json::array(); bool error = false; { std::vector tasks; - std::vector tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true); + std::vector tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); tasks.reserve(tokenized_docs.size()); for (size_t i = 0; i < tokenized_docs.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]); + task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); tasks.push_back(task); } @@ -4064,7 +4127,7 @@ int main(int argc, char ** argv) { }, [&](const json & error_data) { res_error(res, error_data); error = true; - }); + }, req.is_connection_closed); } if (error) { diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index e5e3b60771..c1fc124627 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -1,4 +1,5 @@ import pytest +import requests import time from openai import OpenAI from utils import * @@ -405,3 +406,23 @@ def test_n_probs_post_sampling(): assert "bytes" in prob and type(prob["bytes"]) == list # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs assert any(prob["prob"] == 1.0 for prob in tok["top_probs"]) + + +def test_cancel_request(): + global server + server.n_ctx = 4096 + server.n_predict = -1 + server.n_slots = 1 + server.server_slots = True + server.start() + # send a request that will take a long time, but cancel it before it finishes + try: + server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + }, timeout=0.1) + except requests.exceptions.ReadTimeout: + pass # expected + # make sure the slot is free + time.sleep(1) # wait for HTTP_POLLING_SECONDS + res = server.make_request("GET", "/slots") + assert res.body[0]["is_processing"] == False diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index a1a94d0f15..73be4c92f0 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -219,17 +219,18 @@ class ServerProcess: path: str, data: dict | Any | None = None, headers: dict | None = None, + timeout: float | None = None, ) -> ServerResponse: url = f"http://{self.server_host}:{self.server_port}{path}" parse_body = False if method == "GET": - response = requests.get(url, headers=headers) + response = requests.get(url, headers=headers, timeout=timeout) parse_body = True elif method == "POST": - response = requests.post(url, headers=headers, json=data) + response = requests.post(url, headers=headers, json=data, timeout=timeout) parse_body = True elif method == "OPTIONS": - response = requests.options(url, headers=headers) + response = requests.options(url, headers=headers, timeout=timeout) else: raise ValueError(f"Unimplemented method: {method}") result = ServerResponse() diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index ad130d490e..699480f904 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -118,7 +118,7 @@ static json json_get_nested_values(const std::vector & paths, const * - only string, example: "string" * - mixed string and tokens, example: [12, 34, "string", 56, 78] */ -static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) { +static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { // If `add_bos` is true, we only add BOS, when json_prompt is a string, // or the first element of the json_prompt array is a string. llama_tokens prompt_tokens; @@ -131,10 +131,10 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_ llama_tokens p; if (first) { - p = common_tokenize(ctx, s, add_special, parse_special); + p = common_tokenize(vocab, s, add_special, parse_special); first = false; } else { - p = common_tokenize(ctx, s, false, parse_special); + p = common_tokenize(vocab, s, false, parse_special); } prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); @@ -148,7 +148,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_ } } else { auto s = json_prompt.template get(); - prompt_tokens = common_tokenize(ctx, s, add_special, parse_special); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); } return prompt_tokens; @@ -166,11 +166,11 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_ * - "prompt": [[12, 34, 56], [78, 90, 12]] * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] */ -static std::vector tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) { +static std::vector tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { std::vector result; if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { // string or mixed - result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special)); + result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); } else if (json_is_array_of_numbers(json_prompt)) { // array of tokens result.push_back(json_prompt.get()); @@ -179,7 +179,7 @@ static std::vector tokenize_input_prompts(llama_context * ctx, con result.reserve(json_prompt.size()); for (const auto & p : json_prompt) { if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { - result.push_back(tokenize_mixed(ctx, p, add_special, parse_special)); + result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); } else if (json_is_array_of_numbers(p)) { // array of tokens result.push_back(p.get()); @@ -231,21 +231,23 @@ static size_t validate_utf8(const std::string& text) { // // format rerank task: [BOS]query[EOS][SEP]doc[EOS] -static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) { +static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) { llama_tokens result; + result.reserve(doc.size() + query.size() + 4); - result.push_back(llama_token_bos(model)); + result.push_back(llama_vocab_bos(vocab)); result.insert(result.end(), query.begin(), query.end()); - result.push_back(llama_token_eos(model)); - result.push_back(llama_token_sep(model)); + result.push_back(llama_vocab_eos(vocab)); + result.push_back(llama_vocab_sep(vocab)); result.insert(result.end(), doc.begin(), doc.end()); - result.push_back(llama_token_eos(model)); + result.push_back(llama_vocab_eos(vocab)); + return result; } // format infill task static llama_tokens format_infill( - const llama_context * ctx, + const llama_vocab * vocab, const json & input_prefix, const json & input_suffix, const json & input_extra, @@ -272,15 +274,14 @@ static llama_tokens format_infill( llama_tokens extra_tokens; extra_tokens.reserve(n_ctx); - auto model = llama_get_model(ctx); - auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false); - auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false); + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); - if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { // TODO: make project name an input - static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false); + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); - extra_tokens.push_back(llama_token_fim_rep(model)); + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); } for (const auto & chunk : input_extra) { @@ -288,28 +289,28 @@ static llama_tokens format_infill( const std::string text = json_value(chunk, "text", std::string()); const std::string filename = json_value(chunk, "filename", std::string("tmp")); - if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { - const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false); + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); - extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model)); + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } else { // chunk separator in binary form to avoid confusing the AI static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; - static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false); + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); } - const auto chunk_tokens = common_tokenize(ctx, text, false, false); + const auto chunk_tokens = common_tokenize(vocab, text, false, false); extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); } - if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { // TODO: current filename - static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false); + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); - extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model)); + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } @@ -325,15 +326,15 @@ static llama_tokens format_infill( tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); tokens_suffix.resize(n_suffix_take); - tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model)); + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); - tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model)); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; - if (llama_add_bos_token(model)) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); @@ -342,7 +343,7 @@ static llama_tokens format_infill( embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - embd_inp.push_back(llama_token_fim_mid(model)); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); return embd_inp; } @@ -764,14 +765,18 @@ static json format_logit_bias(const std::vector & logit_bias) return data; } -static std::string safe_json_to_str(json data) { +static std::string safe_json_to_str(const json & data) { return data.dump(-1, ' ', false, json::error_handler_t::replace); } static std::vector get_token_probabilities(llama_context * ctx, int idx) { std::vector cur; const auto * logits = llama_get_logits_ith(ctx, idx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); cur.resize(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { @@ -799,8 +804,8 @@ static std::vector get_token_probabilities(llama_context * ctx } static bool are_lora_equal( - const std::vector & l1, - const std::vector & l2) { + const std::vector & l1, + const std::vector & l2) { if (l1.size() != l2.size()) { return false; } @@ -814,10 +819,10 @@ static bool are_lora_equal( } // parse lora config from JSON request, returned a copy of lora_base with updated scale -static std::vector parse_lora_request( - const std::vector & lora_base, +static std::vector parse_lora_request( + const std::vector & lora_base, const json & data) { - std::vector lora(lora_base); + std::vector lora(lora_base); int max_idx = lora.size(); // clear existing value diff --git a/examples/server/webui/index.html b/examples/server/webui/index.html index 86a79b77f3..2180ef4ad9 100644 --- a/examples/server/webui/index.html +++ b/examples/server/webui/index.html @@ -37,7 +37,7 @@
+ }" @click="setViewingConv(conv.id)" dir="auto"> {{ conv.messages[0].content }}
@@ -156,6 +156,7 @@ @keydown.enter.shift.exact.prevent="inputMsg += '\n'" :disabled="isGenerating" id="msg-input" + dir="auto" > @@ -248,6 +249,7 @@