Merge remote-tracking branch 'upstream/master' into backend-sampling

This commit is contained in:
Daniel Bevenius 2025-12-22 06:46:54 +01:00
commit f1310ab904
No known key found for this signature in database
55 changed files with 2336 additions and 878 deletions

View File

@ -70,6 +70,7 @@ jobs:
with:
key: macOS-latest-cmake-arm64
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@ -106,6 +107,7 @@ jobs:
with:
key: macOS-latest-cmake-x64
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@ -142,6 +144,7 @@ jobs:
with:
key: macOS-latest-cmake-arm64-webgpu
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dawn Dependency
id: dawn-depends
@ -195,6 +198,7 @@ jobs:
with:
key: ubuntu-cpu-cmake-${{ matrix.build }}
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build Dependencies
id: build_depends
@ -276,6 +280,7 @@ jobs:
with:
key: ubuntu-latest-cmake-sanitizer-${{ matrix.sanitizer }}
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -396,6 +401,7 @@ jobs:
with:
key: ubuntu-24-cmake-vulkan-deb
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -431,6 +437,7 @@ jobs:
with:
key: ubuntu-24-cmake-vulkan
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -490,6 +497,7 @@ jobs:
with:
key: ubuntu-24-cmake-webgpu
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -562,6 +570,7 @@ jobs:
with:
key: ubuntu-latest-wasm-webgpu
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install Emscripten
run: |
@ -609,6 +618,7 @@ jobs:
with:
key: ubuntu-22-cmake-hip
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with native CMake HIP support
id: cmake_build
@ -641,6 +651,7 @@ jobs:
with:
key: ubuntu-22-cmake-musa
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with native CMake MUSA support
id: cmake_build
@ -688,6 +699,7 @@ jobs:
with:
key: ubuntu-22-cmake-sycl
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@ -738,6 +750,7 @@ jobs:
with:
key: ubuntu-22-cmake-sycl-fp16
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@ -771,6 +784,7 @@ jobs:
with:
key: macOS-latest-cmake-ios
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@ -802,6 +816,7 @@ jobs:
with:
key: macOS-latest-cmake-tvos
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@ -863,6 +878,7 @@ jobs:
with:
key: macOS-latest-swift
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Download xcframework artifact
uses: actions/download-artifact@v4
@ -905,6 +921,7 @@ jobs:
key: windows-msys2
variant: ccache
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Setup ${{ matrix.sys }}
uses: msys2/setup-msys2@v2
@ -973,6 +990,7 @@ jobs:
key: windows-latest-cmake-${{ matrix.build }}
variant: ccache
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Download OpenBLAS
id: get_openblas
@ -1077,6 +1095,7 @@ jobs:
with:
key: ubuntu-latest-cmake-cuda
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with CMake
# Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled withing CTK and that CTK version is used in this project
@ -1111,6 +1130,7 @@ jobs:
key: windows-cuda-${{ matrix.cuda }}
variant: ccache
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install Cuda Toolkit
uses: ./.github/actions/windows-setup-cuda
@ -1164,6 +1184,7 @@ jobs:
key: windows-latest-cmake-sycl
variant: ccache
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install
run: |
@ -1225,6 +1246,7 @@ jobs:
with:
key: ${{ github.job }}
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@ -1470,6 +1492,7 @@ jobs:
with:
key: ggml-ci-x64-cpu-low-perf
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -1495,6 +1518,7 @@ jobs:
with:
key: ggml-ci-arm64-cpu-low-perf
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -1520,6 +1544,7 @@ jobs:
with:
key: ggml-ci-x64-cpu-high-perf
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -1545,6 +1570,7 @@ jobs:
with:
key: ggml-ci-arm64-cpu-high-perf
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -1570,6 +1596,7 @@ jobs:
with:
key: ggml-ci-arm64-cpu-high-perf-sve
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -1705,6 +1732,7 @@ jobs:
with:
key: ggml-ci-arm64-cpu-kleidiai
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@ -2088,6 +2116,7 @@ jobs:
with:
key: ggml-ci-arm64-graviton4-kleidiai
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Test
id: ggml-ci

View File

@ -66,16 +66,9 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (zip)
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip
name: llama-bin-macos-arm64.zip
- name: Upload artifacts (tar)
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
@ -127,16 +120,9 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (zip)
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
name: llama-bin-macos-x64.zip
- name: Upload artifacts (tar)
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
@ -196,16 +182,9 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (zip)
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip
name: llama-bin-ubuntu-${{ matrix.build }}.zip
- name: Upload artifacts (tar)
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz
@ -256,16 +235,9 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (zip)
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip
name: llama-bin-ubuntu-vulkan-x64.zip
- name: Upload artifacts (tar)
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz
@ -718,16 +690,9 @@ jobs:
- name: Pack artifacts
id: pack_artifacts
run: |
zip -y -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
tar -czvf llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz -C build-apple llama.xcframework
- name: Upload artifacts (zip)
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
name: llama-${{ steps.tag.outputs.name }}-xcframework.zip
- name: Upload artifacts (tar)
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
@ -799,7 +764,7 @@ jobs:
cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (tar)
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
@ -891,9 +856,6 @@ jobs:
with:
tag_name: ${{ steps.tag.outputs.name }}
body: |
> [!WARNING]
> **Release Format Update**: Linux releases will soon use .tar.gz archives instead of .zip. Please make the necessary changes to your deployment scripts.
<details open>
${{ github.event.head_commit.message }}
@ -913,8 +875,8 @@ jobs:
**Windows:**
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip)
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.1-x64.zip)
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.1-x64.zip) - [CUDA 13.1 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.1-x64.zip)
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
- [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip)
- [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip)

View File

@ -1,4 +1,4 @@
# Copilot Instructions for llama.cpp
# Instructions for llama.cpp
## Repository Overview
@ -12,6 +12,24 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
- **License**: MIT
## Disclose AI Usage
It is crucial to remind contributors that the project mandates disclosure of any AI usage in pull requests. This requirement stems from the potential for AI-generated code to include suboptimal optimizations and hidden bugs, owing to the inherent overconfidence in AI outputs.
When generating significant portions of code, address this by:
- Informing the user that AI-generated content may be rejected by maintainers.
- Clearly marking AI-generated code in commit messages and comments.
- Example of commit message: `[AI] Fix a race condition in ...`
- Example of code comment: `// [AI] spawn a new thread ...`
These measures apply to:
- Changes resulting in large portions of code or complex logic.
- Modifications or additions to public APIs in `llama.h`, `ggml.h`, or `mtmd.h`.
- Backend-related changes, such as those involving CPU, CUDA, Metal, Vulkan, etc.
- Modifications to `tools/server`.
Note: These measures can be omitted for small fixes or trivial changes.
## Build Instructions
### Prerequisites
@ -251,6 +269,7 @@ Primary tools:
- **Cross-platform compatibility**: Test on Linux, macOS, Windows when possible
- **Performance focus**: This is a performance-critical inference library
- **API stability**: Changes to `include/llama.h` require careful consideration
- **Disclose AI Usage**: Refer to the "Disclose AI Usage" earlier in this document
### Git Workflow
- Always create feature branches from `master`

View File

@ -85,6 +85,9 @@ add_library(${TARGET} STATIC
unicode.h
)
target_include_directories(${TARGET} PUBLIC . ../vendor)
target_compile_features (${TARGET} PUBLIC cxx_std_17)
if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
@ -151,9 +154,7 @@ if (LLAMA_LLGUIDANCE)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
endif ()
target_include_directories(${TARGET} PUBLIC . ../vendor)
target_compile_features (${TARGET} PUBLIC cxx_std_17)
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
#

View File

@ -96,6 +96,11 @@ common_arg & common_arg::set_sparam() {
return *this;
}
common_arg & common_arg::set_preset_only() {
is_preset_only = true;
return *this;
}
bool common_arg::in_example(enum llama_example ex) {
return examples.find(ex) != examples.end();
}
@ -772,6 +777,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
}
auto opt = *arg_to_options[arg];
std::string val;
if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
// bool arg (need to reverse the meaning for negative args)
bool is_neg = std::find(opt.args_neg.begin(), opt.args_neg.end(), arg) != opt.args_neg.end();
val = is_neg ? "0" : "1";
}
if (opt.value_hint != nullptr) {
// arg with single value
check_arg(i);
@ -1139,7 +1149,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--cache-ram", "-cram"}, "N",
{"-cram", "--cache-ram"}, "N",
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
[](common_params & params, int value) {
@ -1147,7 +1157,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
{"-kvu", "--kv-unified"},
"use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)",
[](common_params & params) {
params.kv_unified = true;
@ -1415,7 +1425,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
{"--sampling-seq", "--sampler-seq"}, "SEQUENCE",
{"--sampler-seq", "--sampling-seq"}, "SEQUENCE",
string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
[](common_params & params, const std::string & value) {
params.sampling.samplers = common_sampler_types_from_chars(value);
@ -2080,26 +2090,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
));
add_opt(common_arg(
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
{"-ot", "--override-tensor"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
}
));
add_opt(common_arg(
{"--override-tensor-draft", "-otd"}, "<tensor name pattern>=<buffer type>,...",
{"-otd", "--override-tensor-draft"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--cpu-moe", "-cmoe"},
{"-cmoe", "--cpu-moe"},
"keep all Mixture of Experts (MoE) weights in the CPU",
[](common_params & params) {
params.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
}
).set_env("LLAMA_ARG_CPU_MOE"));
add_opt(common_arg(
{"--n-cpu-moe", "-ncmoe"}, "N",
{"-ncmoe", "--n-cpu-moe"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU",
[](common_params & params, int value) {
if (value < 0) {
@ -2114,14 +2124,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_N_CPU_MOE"));
add_opt(common_arg(
{"--cpu-moe-draft", "-cmoed"},
{"-cmoed", "--cpu-moe-draft"},
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
[](common_params & params) {
params.speculative.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
add_opt(common_arg(
{"--n-cpu-moe-draft", "-ncmoed"}, "N",
{"-ncmoed", "--n-cpu-moe-draft"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
[](common_params & params, int value) {
if (value < 0) {
@ -2649,7 +2659,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(common_arg(
{"--reranking", "--rerank"},
{"--rerank", "--reranking"},
string_format("enable reranking endpoint on server (default: %s)", "disabled"),
[](common_params & params) {
params.embedding = true;
@ -2884,6 +2894,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.lora_init_without_apply = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--sleep-idle-seconds"}, "SECONDS",
string_format("number of seconds of idleness after which the server will sleep (default: %d; -1 = disabled)", params.sleep_idle_seconds),
[](common_params & params, int value) {
if (value == 0 || value < -1) {
throw std::invalid_argument("invalid value: cannot be 0 or less than -1");
}
params.sleep_idle_seconds = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--simple-io"},
"use basic IO for better compatibility in subprocesses and limited consoles",
@ -3120,7 +3140,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--draft-max", "--draft", "--draft-n"}, "N",
{"--draft", "--draft-n", "--draft-max"}, "N",
string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max),
[](common_params & params, int value) {
params.speculative.n_max = value;
@ -3496,3 +3516,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
return ctx_arg;
}
void common_params_add_preset_options(std::vector<common_arg> & args) {
// arguments below won't be treated as CLI args, only preset options
args.push_back(common_arg(
{"load-on-startup"}, "NAME",
"in server router mode, autoload this model on startup",
[](common_params &, const std::string &) { /* unused */ }
).set_env(COMMON_ARG_PRESET_LOAD_ON_STARTUP).set_preset_only());
// args.push_back(common_arg(
// {"pin"},
// "in server router mode, do not unload this model if models_max is exceeded",
// [](common_params &) { /* unused */ }
// ).set_preset_only());
// args.push_back(common_arg(
// {"unload-idle-seconds"}, "SECONDS",
// "in server router mode, unload models idle for more than this many seconds",
// [](common_params &, int) { /* unused */ }
// ).set_preset_only());
}

View File

@ -8,6 +8,9 @@
#include <vector>
#include <cstring>
// pseudo-env variable to identify preset-only arguments
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
//
// CLI argument parsing
//
@ -22,6 +25,7 @@ struct common_arg {
const char * env = nullptr;
std::string help;
bool is_sparam = false; // is current arg a sampling param?
bool is_preset_only = false; // is current arg preset-only (not treated as CLI arg)
void (*handler_void) (common_params & params) = nullptr;
void (*handler_string) (common_params & params, const std::string &) = nullptr;
void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr;
@ -70,6 +74,7 @@ struct common_arg {
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
common_arg & set_env(const char * env);
common_arg & set_sparam();
common_arg & set_preset_only();
bool in_example(enum llama_example ex);
bool is_exclude(enum llama_example ex);
bool get_value_from_env(std::string & output) const;
@ -114,9 +119,13 @@ struct common_params_context {
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
// parse input arguments from CLI into a map
// TODO: support repeated args in the future
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
// populate preset-only arguments
// these arguments are not treated as command line arguments
// see: https://github.com/ggml-org/llama.cpp/issues/18163
void common_params_add_preset_options(std::vector<common_arg> & args);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

View File

@ -477,7 +477,8 @@ struct common_params {
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
std::vector<std::string> api_keys;

View File

@ -2,6 +2,7 @@
#include "preset.h"
#include "peg-parser.h"
#include "log.h"
#include "download.h"
#include <fstream>
#include <sstream>
@ -15,11 +16,22 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
std::vector<std::string> common_preset::to_args() const {
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
if (!bin_path.empty()) {
args.push_back(bin_path);
}
for (const auto & [opt, value] : options) {
args.push_back(opt.args.back()); // use the last arg as the main arg
if (opt.is_preset_only) {
continue; // skip preset-only options (they are not CLI args)
}
// use the last arg as the main arg (i.e. --long-form)
args.push_back(opt.args.back());
// handle value(s)
if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
// flag option, no value
if (common_arg_utils::is_falsey(value)) {
@ -63,6 +75,52 @@ std::string common_preset::to_ini() const {
return ss.str();
}
void common_preset::set_option(const common_preset_context & ctx, const std::string & env, const std::string & value) {
// try if option exists, update it
for (auto & [opt, val] : options) {
if (opt.env && env == opt.env) {
val = value;
return;
}
}
// if option does not exist, we need to add it
if (ctx.key_to_opt.find(env) == ctx.key_to_opt.end()) {
throw std::runtime_error(string_format(
"%s: option with env '%s' not found in ctx_params",
__func__, env.c_str()
));
}
options[ctx.key_to_opt.at(env)] = value;
}
void common_preset::unset_option(const std::string & env) {
for (auto it = options.begin(); it != options.end(); ) {
const common_arg & opt = it->first;
if (opt.env && env == opt.env) {
it = options.erase(it);
return;
} else {
++it;
}
}
}
bool common_preset::get_option(const std::string & env, std::string & value) const {
for (const auto & [opt, val] : options) {
if (opt.env && env == opt.env) {
value = val;
return true;
}
}
return false;
}
void common_preset::merge(const common_preset & other) {
for (const auto & [opt, val] : other.options) {
options[opt] = val; // overwrite existing options
}
}
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
std::map<std::string, std::map<std::string, std::string>> parsed;
@ -172,9 +230,14 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
return value;
}
common_presets common_presets_load(const std::string & path, common_params_context & ctx_params) {
common_preset_context::common_preset_context(llama_example ex)
: ctx_params(common_params_parser_init(default_params, ex)) {
common_params_add_preset_options(ctx_params.options);
key_to_opt = get_map_key_opt(ctx_params);
}
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
common_presets out;
auto key_to_opt = get_map_key_opt(ctx_params);
auto ini_data = parse_ini_from_file(path);
for (auto section : ini_data) {
@ -188,7 +251,7 @@ common_presets common_presets_load(const std::string & path, common_params_conte
for (const auto & [key, value] : section.second) {
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
if (key_to_opt.find(key) != key_to_opt.end()) {
auto & opt = key_to_opt[key];
const auto & opt = key_to_opt.at(key);
if (is_bool_arg(opt)) {
preset.options[opt] = parse_bool_arg(opt, key, value);
} else {
@ -199,8 +262,137 @@ common_presets common_presets_load(const std::string & path, common_params_conte
// TODO: maybe warn about unknown key?
}
}
if (preset.name == "*") {
// handle global preset
global = preset;
} else {
out[preset.name] = preset;
}
}
return out;
}
common_presets common_preset_context::load_from_cache() const {
common_presets out;
auto cached_models = common_list_cached_models();
for (const auto & model : cached_models) {
common_preset preset;
preset.name = model.to_string();
preset.set_option(*this, "LLAMA_ARG_HF_REPO", model.to_string());
out[preset.name] = preset;
}
return out;
}
struct local_model {
std::string name;
std::string path;
std::string path_mmproj;
};
common_presets common_preset_context::load_from_models_dir(const std::string & models_dir) const {
if (!std::filesystem::exists(models_dir) || !std::filesystem::is_directory(models_dir)) {
throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", models_dir.c_str()));
}
std::vector<local_model> models;
auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
auto files = fs_list(subdir_path, false);
common_file_info model_file;
common_file_info first_shard_file;
common_file_info mmproj_file;
for (const auto & file : files) {
if (string_ends_with(file.name, ".gguf")) {
if (file.name.find("mmproj") != std::string::npos) {
mmproj_file = file;
} else if (file.name.find("-00001-of-") != std::string::npos) {
first_shard_file = file;
} else {
model_file = file;
}
}
}
// single file model
local_model model{
/* name */ name,
/* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
/* path_mmproj */ mmproj_file.path // can be empty
};
if (!model.path.empty()) {
models.push_back(model);
}
};
auto files = fs_list(models_dir, true);
for (const auto & file : files) {
if (file.is_dir) {
scan_subdir(file.path, file.name);
} else if (string_ends_with(file.name, ".gguf")) {
// single file model
std::string name = file.name;
string_replace_all(name, ".gguf", "");
local_model model{
/* name */ name,
/* path */ file.path,
/* path_mmproj */ ""
};
models.push_back(model);
}
}
// convert local models to presets
common_presets out;
for (const auto & model : models) {
common_preset preset;
preset.name = model.name;
preset.set_option(*this, "LLAMA_ARG_MODEL", model.path);
if (!model.path_mmproj.empty()) {
preset.set_option(*this, "LLAMA_ARG_MMPROJ", model.path_mmproj);
}
out[preset.name] = preset;
}
return out;
}
common_preset common_preset_context::load_from_args(int argc, char ** argv) const {
common_preset preset;
preset.name = COMMON_PRESET_DEFAULT_NAME;
bool ok = common_params_to_map(argc, argv, ctx_params.ex, preset.options);
if (!ok) {
throw std::runtime_error("failed to parse CLI arguments into preset");
}
return preset;
}
common_presets common_preset_context::cascade(const common_presets & base, const common_presets & added) const {
common_presets out = base; // copy
for (const auto & [name, preset_added] : added) {
if (out.find(name) != out.end()) {
// if exists, merge
common_preset & target = out[name];
target.merge(preset_added);
} else {
// otherwise, add directly
out[name] = preset_added;
}
}
return out;
}
common_presets common_preset_context::cascade(const common_preset & base, const common_presets & presets) const {
common_presets out;
for (const auto & [name, preset] : presets) {
common_preset tmp = base; // copy
tmp.name = name;
tmp.merge(preset);
out[name] = std::move(tmp);
}
return out;
}

View File

@ -13,20 +13,62 @@
constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default";
struct common_preset_context;
struct common_preset {
std::string name;
// TODO: support repeated args in the future
// options are stored as common_arg to string mapping, representing CLI arg and its value
std::map<common_arg, std::string> options;
// convert preset to CLI argument list
std::vector<std::string> to_args() const;
std::vector<std::string> to_args(const std::string & bin_path = "") const;
// convert preset to INI format string
std::string to_ini() const;
// TODO: maybe implement to_env() if needed
// modify preset options where argument is identified by its env variable
void set_option(const common_preset_context & ctx, const std::string & env, const std::string & value);
// unset option by its env variable
void unset_option(const std::string & env);
// get option value by its env variable, return false if not found
bool get_option(const std::string & env, std::string & value) const;
// merge another preset into this one, overwriting existing options
void merge(const common_preset & other);
};
// interface for multiple presets in one file
using common_presets = std::map<std::string, common_preset>;
common_presets common_presets_load(const std::string & path, common_params_context & ctx_params);
// context for loading and editing presets
struct common_preset_context {
common_params default_params; // unused for now
common_params_context ctx_params;
std::map<std::string, common_arg> key_to_opt;
common_preset_context(llama_example ex);
// load presets from INI file
common_presets load_from_ini(const std::string & path, common_preset & global) const;
// generate presets from cached models
common_presets load_from_cache() const;
// generate presets from local models directory
// for the directory structure, see "Using multiple models" in server/README.md
common_presets load_from_models_dir(const std::string & models_dir) const;
// generate one preset from CLI arguments
common_preset load_from_args(int argc, char ** argv) const;
// cascade multiple presets if exist on both: base < added
// if preset does not exist in base, it will be added without modification
common_presets cascade(const common_presets & base, const common_presets & added) const;
// apply presets over a base preset (same idea as CSS cascading)
common_presets cascade(const common_preset & base, const common_presets & presets) const;
};

View File

@ -22,6 +22,7 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_CURL": "OFF"
}
},
@ -36,6 +37,7 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_CURL": "OFF"
}
},

View File

@ -55,7 +55,7 @@ auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder &
```
For a more complete example, see `test_example_native()` in
[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp).
[tests/test-chat-peg-parser.cpp](/tests/test-chat-peg-parser.cpp).
## Parsers/Combinators
@ -175,7 +175,7 @@ Most model output can be placed in one of the following categories:
(Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2)
To provide broad coverage,
[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and
[`common/chat-peg-parser.h`](/common/chat-peg-parser.h) contains builders and
mappers that help create parsers and visitors/extractors for these types. They
require parsers to tag nodes to conform to an AST "shape". This normalization
makes it easy to extract information and generalize parsing.

View File

@ -254,6 +254,7 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
"gmml: OpenCL API version to target")
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
# toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")

View File

@ -3089,8 +3089,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
@ -3098,7 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
@ -3107,8 +3114,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}

View File

@ -63,6 +63,9 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
// Heuristic for block size selection to optimize occupancy.
// See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);

View File

@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert) {
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
return false;
}
float scale = 1.0f;
float max_bias = 0.0f;
@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return false;
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;

View File

@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

View File

@ -2,6 +2,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
add_library(htp_iface OBJECT
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
@ -41,7 +42,8 @@ set(HTP_CMAKE_ARGS
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG})
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
ExternalProject_Add(htp-v68
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON

View File

@ -31,7 +31,8 @@ add_library(${HTP_LIB} SHARED
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>)
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
build_idl(htp_iface.idl ${HTP_LIB})

View File

@ -92,6 +92,18 @@ static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
// vdelta control to replicate first fp16 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
@ -1594,6 +1606,118 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
// *** dynamic quant
static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0);
HVX_Vector * vx = (HVX_Vector *) x;
HVX_Vector zero = Q6_V_vsplat_R(0);
// Use reduce max fp32 to find max(abs(e)) first
HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0]));
HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1]));
HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2]));
HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3]));
// Load and convert into QF32
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert to QF32
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
// Combine and convert to fp16
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
// Convert into fp16
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
hvx_vec_store_u(y_d + 0, 2, vd01_hf);
HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);
hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);
hvx_vec_store_u(y_d + 4, 2, vd23_hf);
rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);
hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
// Divide input by the scale
HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
// Convert to int8
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
*(HVX_Vector *) y_q = vx_i8;
}
static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0);
HVX_Vector * vx = (HVX_Vector *) x;
// Load and convert into QF32
HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert into fp16
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Compute max and scale
HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
hvx_vec_store_u(y_d + 0, 4, vd01_hf);
hvx_vec_store_u(y_d + 4, 4, vd23_hf);
// Divide input by the scale
HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
// Convert to int8
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
*(HVX_Vector *) y_q = vx_i8;
}
static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0);
@ -1655,10 +1779,24 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u
uint8_t * restrict t_d = (uint8_t *) x;
for (uint32_t i = 0; i < nb; i++) {
#if FP32_QUANTIZE_GROUP_SIZE == 32
quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
t_d + (i * 2 + 0) * dblk_size / 2);
quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
t_d + (i * 2 + 1) * dblk_size / 2);
#elif FP32_QUANTIZE_GROUP_SIZE == 64
quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
t_d + (i * 2 + 0) * dblk_size / 2);
quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
t_d + (i * 2 + 1) * dblk_size / 2);
#elif FP32_QUANTIZE_GROUP_SIZE == 128
quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
t_d + (i * 2 + 0) * dblk_size / 2);
quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
t_d + (i * 2 + 1) * dblk_size / 2);
#else
#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
#endif
}
// now copy the scales into final location
@ -1671,6 +1809,7 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
uint32_t nth,
uint32_t ith,
uint32_t nrows_per_thread) {
uint64_t t1 = HAP_perf_get_qtimer_count();
const uint32_t ne0 = src->ne[0];

View File

@ -583,7 +583,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
if (tensor->buffer) {
ggml_backend_buffer_t buffer = tensor->buffer;
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
result.buffer = ctx->remote_ptr;
result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
} else {
result.buffer = 0;
}

View File

@ -689,6 +689,7 @@ struct vk_device_struct {
vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
vk_pipeline pipeline_xielu[2];
vk_pipeline pipeline_neg[2];
vk_pipeline pipeline_tanh[2];
vk_pipeline pipeline_sigmoid[2];
@ -855,6 +856,15 @@ struct vk_subbuffer {
}
};
// vk_event is used for the event-related backend interfaces. It uses 'event' for
// event_wait and 'fence' for event_synchronize. Polling on an event for
// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
// and would lead to validation errors.
struct vk_event {
vk::Event event;
vk::Fence fence;
};
struct vk_semaphore {
vk::Semaphore s;
uint64_t value;
@ -990,6 +1000,8 @@ struct vk_op_push_constants {
uint32_t KY;
float param1;
float param2;
float param3;
float param4;
};
struct vk_op_glu_push_constants {
@ -1258,6 +1270,7 @@ struct vk_op_im2col_push_constants {
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
int32_t d0; int32_t d1;
uint32_t batch_IC;
};
struct vk_op_im2col_3d_push_constants {
@ -2540,6 +2553,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
);
}
static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
VK_LOG_DEBUG("ggml_vk_set_event()");
ctx->s->buffer.setEvent(
event,
ctx->p->q->stage_flags
);
}
static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
VK_LOG_DEBUG("ggml_vk_wait_events()");
if (events.empty()) {
@ -3973,6 +3995,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY(gelu_quick)
CREATE_UNARY(silu)
CREATE_UNARY(relu)
CREATE_UNARY(xielu)
CREATE_UNARY(neg)
CREATE_UNARY(tanh)
CREATE_UNARY(sigmoid)
@ -5898,6 +5921,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
}
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@ -6081,13 +6107,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
}
}
static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
// Buffer is already mapped
if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
GGML_ABORT("fatal error");
}
// Check if src is pinned memory
vk_buffer buf = nullptr;
size_t buf_offset = 0;
@ -6112,12 +6133,13 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
return;
return true;
}
VK_LOG_DEBUG("STAGING");
if (!sync_staging) {
GGML_ABORT("Asynchronous write to non-pinned memory not supported");
// copy was not handled caller needs to fall back
return false;
}
// Staging buffer required
@ -6141,9 +6163,10 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
}
}
return true;
}
static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
}
@ -6162,7 +6185,8 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dst->device, subctx);
ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
GGML_ASSERT(ret);
ggml_vk_ctx_end(subctx);
for (auto& cpy : subctx->in_memcpys) {
@ -8549,6 +8573,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_RELU:
return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_XIELU:
return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_NEG:
return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_TANH:
@ -9084,6 +9110,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
case GGML_OP_IM2COL_3D:
{
@ -9695,14 +9723,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
ggml_vk_op_f32_opt_step_adamw(
ctx, subctx, dst,
{ (uint32_t)n, 0, 0.0f, 0.0f }
{ (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
);
}
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const size_t n = ggml_nelements(dst->src[0]);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -9788,6 +9816,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
1,
ggml_get_op_params_f32(dst, 0),
ggml_get_op_params_f32(dst, 2),
0.0f, 0.0f,
};
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
@ -9809,6 +9838,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml
1,
ggml_get_op_params_f32(dst, 0),
0.0f,
0.0f, 0.0f,
};
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
@ -9924,13 +9954,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
}
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@ -9941,7 +9971,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
const float eps = float_op_params[1];
const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
}
static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@ -10110,16 +10140,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
{
(uint32_t)ggml_nelements(src0), 0,
op_params[1], op_params[2], op_params[3], op_params[4]
}
);
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -10244,7 +10284,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
@ -10541,11 +10581,11 @@ static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -10587,6 +10627,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH;
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@ -10599,7 +10640,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
IC, IW, IH, OW, OH, KW, KH,
pelements,
IC * KH * KW,
s0, s1, p0, p1, d0, d1,
s0, s1, p0, p1, d0, d1, batch * IC
});
}
@ -10804,7 +10845,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const float * op_params = (const float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
}
#ifdef GGML_VULKAN_RUN_TESTS
@ -12050,6 +12091,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_TRUNC:
ggml_vk_unary(ctx, compute_ctx, src0, node);
break;
case GGML_UNARY_OP_XIELU:
ggml_vk_xielu(ctx, compute_ctx, src0, node);
break;
default:
return false;
}
@ -12643,7 +12687,23 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
vk_buffer buf = buf_ctx->dev_buffer;
ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size);
if (!ret) {
ggml_vk_ensure_sync_staging_buffer(ctx, size);
ggml_vk_sync_buffers(nullptr, transfer_ctx);
vk::BufferCopy buffer_cpy;
buffer_cpy.srcOffset = 0;
buffer_cpy.dstOffset = dst_offset;
buffer_cpy.size = size;
transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys);
ggml_vk_synchronize(ctx);
}
}
static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@ -12920,24 +12980,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
const ggml_tensor * softmax;
const ggml_tensor * weights;
const ggml_tensor * get_rows;
const ggml_tensor * argsort;
switch (mode) {
case TOPK_MOE_EARLY_SOFTMAX_NORM:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 9];
get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_EARLY_SOFTMAX:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 4];
get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_LATE_SOFTMAX:
softmax = cgraph->nodes[node_idx + 4];
weights = cgraph->nodes[node_idx + 5];
get_rows = cgraph->nodes[node_idx + 2];
argsort = cgraph->nodes[node_idx + 0];
break;
default:
return false;
}
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
return false;
}
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
@ -13502,7 +13581,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
ok = false;
break;
}
@ -13630,11 +13710,58 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
}
}
static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
vk_event *vkev = (vk_event *)event->context;
vk_context transfer_ctx;
if (ctx->transfer_ctx.expired()) {
// Initialize new transfer context
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->transfer_ctx = transfer_ctx;
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
} else {
transfer_ctx = ctx->transfer_ctx.lock();
}
// the backend interface doesn't have an explicit reset, so reset it here
// before we record the command to set it
ctx->device->device.resetEvent(vkev->event);
ctx->device->device.resetFences({ vkev->fence });
ggml_vk_set_event(transfer_ctx, vkev->event);
ggml_vk_ctx_end(transfer_ctx);
ggml_vk_submit(transfer_ctx, {vkev->fence});
ctx->submit_pending = true;
ctx->transfer_ctx.reset();
}
static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
vk_event *vkev = (vk_event *)event->context;
vk_context transfer_ctx;
if (ctx->transfer_ctx.expired()) {
// Initialize new transfer context
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->transfer_ctx = transfer_ctx;
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
} else {
transfer_ctx = ctx->transfer_ctx.lock();
}
ggml_vk_wait_events(transfer_ctx, {vkev->event});
}
// TODO: enable async and synchronize
static ggml_backend_i ggml_backend_vk_interface = {
/* .get_name = */ ggml_backend_vk_name,
/* .free = */ ggml_backend_vk_free,
/* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
/* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
/* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
/* .synchronize = */ ggml_backend_vk_synchronize,
@ -13643,8 +13770,8 @@ static ggml_backend_i ggml_backend_vk_interface = {
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_vk_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .event_record = */ ggml_backend_vk_event_record,
/* .event_wait = */ ggml_backend_vk_event_wait,
/* .graph_optimize = */ ggml_vk_graph_optimize,
};
@ -13819,10 +13946,10 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* .async = */ false,
/* .async = */ true,
/* .host_buffer = */ true,
/* .buffer_from_host_ptr = */ false,
/* .events = */ false,
/* .events = */ true,
};
}
@ -13842,6 +13969,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_XIELU:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
@ -14353,6 +14481,46 @@ static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml
UNUSED(dev);
}
static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
vk_event *vkev = new vk_event;
if (!vkev) {
return nullptr;
}
// The event/fence is expected to initially be in the signaled state.
vkev->event = device->device.createEvent({});
vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
device->device.setEvent(vkev->event);
return new ggml_backend_event {
/* .device = */ dev,
/* .context = */ vkev,
};
}
static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
vk_event *vkev = (vk_event *)event->context;
device->device.destroyFence(vkev->fence);
device->device.destroyEvent(vkev->event);
delete vkev;
delete event;
}
static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
vk_event *vkev = (vk_event *)event->context;
VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
}
static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
/* .get_name = */ ggml_backend_vk_device_get_name,
/* .get_description = */ ggml_backend_vk_device_get_description,
@ -14366,9 +14534,9 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
/* .supports_op = */ ggml_backend_vk_device_supports_op,
/* .supports_buft = */ ggml_backend_vk_device_supports_buft,
/* .offload_op = */ ggml_backend_vk_device_offload_op,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
/* .event_new = */ ggml_backend_vk_device_event_new,
/* .event_free = */ ggml_backend_vk_device_event_free,
/* .event_synchronize = */ ggml_backend_vk_device_event_synchronize,
};
static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
@ -14747,7 +14915,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_LOG) {
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_TRI) {
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
} else if (tensor->op == GGML_OP_DIAG) {
tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {
@ -14835,6 +15003,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_RELU:
tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_XIELU:
tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
break;
case GGML_UNARY_OP_NEG:
tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
break;

View File

@ -6,4 +6,6 @@ layout (push_constant) uniform parameter
uint KY;
float param1;
float param2;
float param3;
float param4;
} p;

View File

@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
int s0; int s1;
int p0; int p1;
int d0; int d1;
uint batch_IC;
} p;
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
void im2col(const uint y, const uint z) {
const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint oh = y;
const uint batch = z / p.IC;
const uint ic = z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
@ -101,3 +102,15 @@ void main() {
#endif
}
}
void main() {
uint y = gl_GlobalInvocationID.y;
while (y < p.OH) {
uint z = gl_GlobalInvocationID.z;
while (z < p.batch_IC) {
im2col(y, z);
z += gl_NumWorkGroups.z;
}
y += gl_NumWorkGroups.y;
}
}

View File

@ -11,36 +11,54 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
// Precompute db multiplication factors
float db_vals[NUM_ROWS];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
const float db = d * (0.5 + scale) * 0.25;
const uint scale_raw = data_a[ibi].scales[ib32];
const uint scale = (scale_raw >> nibble_shift) & 0xF;
// Merge constant calculations d * (0.5 + scale) * 0.25 = d*0.125 + d*scale*0.25
db_vals[n] = d * (0.125f + float(scale) * 0.25f);
ibi += num_blocks_per_row;
}
ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
// Preload grid and sign data for all l values
vec4 grid0_vals[2], grid1_vals[2];
uint sign_vals[2], sign7_vals[2];
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs = data_a[ibi].qs[2 * itid + l];
const uint sign = qs >> 9;
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x));
const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
sign_vals[l] = qs >> 9;
sign7_vals[l] = bitCount(sign_vals[l]);
const uvec2 grid_data = iq2xs_grid[qs & 511];
grid0_vals[l] = vec4(unpack8(grid_data.x));
grid1_vals[l] = vec4(unpack8(grid_data.y));
}
// Preload B data for all j columns (reduce repeated index calculations)
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint sign = sign_vals[l];
const uint sign7 = sign7_vals[l];
const vec4 grid0 = grid0_vals[l];
const vec4 grid1 = grid1_vals[l];
// Precompute indices
const uint b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4 + 2 * l;
const vec4 b0 = vec4(data_b_v4[b_idx + 0]);
const vec4 b4 = vec4(data_b_v4[b_idx + 1]);
sum +=
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
}
temp[j][n] = fma(FLOAT_TYPE(db_vals[n]), sum, temp[j][n]);
}
ibi += num_blocks_per_row;
}

View File

@ -853,6 +853,8 @@ void process_shaders() {
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

View File

@ -0,0 +1,35 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
float x = float(data_a[i]);
float alpha_n = p.param1;
float alpha_p = p.param2;
float beta = p.param3;
float eps = p.param4;
if (x > 0.0f) {
x = alpha_p * x * x + beta * x;
} else {
const float min_x_eps = min(x, eps);
x = (exp(min_x_eps) - 1 - x) * alpha_n + beta * x;
}
data_d[i] = D_TYPE(x);
}

View File

@ -1086,10 +1086,10 @@ bool llama_model_loader::load_all_data(
} else {
// If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
if (upload_backend) {
auto offset = (off_t) weight->offs;
size_t offset = weight->offs;
alignment = file->read_alignment();
off_t aligned_offset = offset & ~(alignment - 1);
off_t offset_from_alignment = offset - aligned_offset;
size_t aligned_offset = offset & ~(alignment - 1);
size_t offset_from_alignment = offset - aligned_offset;
file->seek(aligned_offset, SEEK_SET);
// Calculate aligned read boundaries

View File

@ -16,6 +16,7 @@ int main(void) {
for (int ex = 0; ex < LLAMA_EXAMPLE_COUNT; ex++) {
try {
auto ctx_arg = common_params_parser_init(params, (enum llama_example)ex);
common_params_add_preset_options(ctx_arg.options);
std::unordered_set<std::string> seen_args;
std::unordered_set<std::string> seen_env_vars;
for (const auto & opt : ctx_arg.options) {
@ -37,6 +38,30 @@ int main(void) {
exit(1);
}
}
// ensure shorter argument precedes longer argument
if (opt.args.size() > 1) {
const std::string first(opt.args.front());
const std::string last(opt.args.back());
if (first.length() > last.length()) {
fprintf(stderr, "test-arg-parser: shorter argument should come before longer one: %s, %s\n",
first.c_str(), last.c_str());
assert(false);
}
}
// same check for negated arguments
if (opt.args_neg.size() > 1) {
const std::string first(opt.args_neg.front());
const std::string last(opt.args_neg.back());
if (first.length() > last.length()) {
fprintf(stderr, "test-arg-parser: shorter negated argument should come before longer one: %s, %s\n",
first.c_str(), last.c_str());
assert(false);
}
}
}
} catch (std::exception & e) {
printf("%s\n", e.what());

View File

@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
}
};
enum MoeGatingFunc {
GATING_FUNC_SOFTMAX,
GATING_FUNC_SIGMOID,
GATING_FUNC_SOFTMAX_WEIGHT,
};
struct test_topk_moe : public test_case {
const std::array<int64_t, 4> ne;
const int n_expert_used;
const bool with_norm;
const bool delayed_softmax;
const bool bias_probs;
const MoeGatingFunc gating_func;
const float scale_w;
test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
int n_expert_used = 1,
bool with_norm = false,
bool delayed_softmax = false) :
bool bias_probs = false,
MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
float scale_w = 0.0f) :
ne(ne),
n_expert_used(n_expert_used),
with_norm(with_norm),
delayed_softmax(delayed_softmax) {
bias_probs(bias_probs),
gating_func(gating_func),
scale_w(scale_w) {
GGML_ASSERT(n_expert_used <= ne[0]);
GGML_ASSERT(!(with_norm && delayed_softmax));
}
std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
const int n_tokens = ne[1];
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * probs =
(gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
(gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
ggml_set_name(probs, "probs");
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
ggml_tensor * selection_probs = probs;
if (bias_probs) {
ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_set_name(exp_probs_b, "exp_probs_b");
selection_probs = ggml_add(ctx, probs, exp_probs_b);
ggml_set_name(selection_probs, "selection_probs");
}
if (delayed_softmax) {
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_set_name(selected_experts, "selected_experts");
ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
ggml_set_name(weights, "weights");
if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
}
if (with_norm) {
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
ggml_set_name(weights_sum, "weights_sum");
weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
}
ggml_set_name(out, "out");
return out;
if (scale_w) {
weights = ggml_scale(ctx, weights, scale_w);
}
ggml_set_name(weights, "weights");
return weights;
}
};
@ -5344,6 +5374,13 @@ struct test_sum : public test_case {
float grad_eps() override {
return 0.1f * sqrtf(ne[0]*ne[1]*ne[2]*ne[3]);
}
// Don't center the distribution around zero. Helps to avoid catastrophic cancellation.
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -0.9f, 1.1f);
}
}
};
// GGML_OP_SUM_ROWS
@ -5410,6 +5447,13 @@ struct test_mean : public test_case {
float grad_eps() override {
return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
}
// Don't center the distribution around zero. Helps to avoid catastrophic cancellation.
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -0.9f, 1.1f);
}
}
};
// GGML_OP_UPSCALE
@ -6710,6 +6754,11 @@ static const ggml_type other_types[] = {
GGML_TYPE_BF16,
};
#ifdef _MSC_VER
// Workaround long compile time with msvc
#pragma optimize("", off)
#endif
// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
std::vector<std::unique_ptr<test_case>> test_cases;
@ -6881,6 +6930,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
// im2col 3D
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
@ -7975,19 +8025,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
for (bool with_norm : {false, true}) {
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
for (bool with_norm : {false, true}) {
for (bool bias_probs : {false, true}) {
for (float scale_w : {0.0f, 2.0f}) {
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
}
}
}
}
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
#if 0
// these tests are disabled to save execution time, sbut they can be handy for debugging
test_cases.emplace_back(new test_llama(2, true));
@ -7999,6 +8052,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
return test_cases;
}
#ifdef _MSC_VER
#pragma optimize("", on)
#endif
// Test cases for performance evaluation: should be representative of real-world use cases
static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {

View File

@ -209,8 +209,6 @@ int main(int argc, char ** argv) {
return 1;
}
ctx_cli.ctx_server.init();
console::spinner::stop();
console::log("\n");

View File

@ -107,6 +107,8 @@ For detailed instructions, see the [test documentation](./tests/README.md).
- Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362
- Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470
- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808
- INI presets: https://github.com/ggml-org/llama.cpp/pull/17859 (+ refactoring: https://github.com/ggml-org/llama.cpp/pull/18169)
- Sleeping mode: https://github.com/ggml-org/llama.cpp/pull/18228

View File

@ -75,9 +75,9 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggml-org/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) |
| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) |
| `--list-devices` | print list of available devices and exit |
| `--override-tensor, -ot <tensor name pattern>=<buffer type>,...` | override tensor buffer type |
| `--cpu-moe, -cmoe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) |
| `--n-cpu-moe, -ncmoe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) |
| `-ot, --override-tensor <tensor name pattern>=<buffer type>,...` | override tensor buffer type |
| `-cmoe, --cpu-moe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) |
| `-ncmoe, --n-cpu-moe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) |
| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
@ -120,7 +120,7 @@ For the ful list of features, please refer to [server's changelog](https://githu
| -------- | ----------- |
| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature) |
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
| `--sampling-seq, --sampler-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) |
| `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) |
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
| `--temp N` | temperature (default: 0.8) |
| `--top-k N` | top-k sampling (default: 40, 0 = disabled)<br/>(env: LLAMA_ARG_TOP_K) |
@ -156,8 +156,8 @@ For the ful list of features, please refer to [server's changelog](https://githu
| Argument | Explanation |
| -------- | ----------- |
| `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_CTX_CHECKPOINTS) |
| `--cache-ram, -cram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
| `--kv-unified, -kvu` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
| `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
| `-kvu, --kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
| `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_CONTEXT_SHIFT) |
| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode<br/> |
| `-sp, --special` | special tokens output enabled (default: false) |
@ -172,9 +172,9 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)<br/>(env: LLAMA_ARG_MMPROJ_OFFLOAD) |
| `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MIN_TOKENS) |
| `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MAX_TOKENS) |
| `--override-tensor-draft, -otd <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model |
| `--cpu-moe-draft, -cmoed` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>(env: LLAMA_ARG_CPU_MOE_DRAFT) |
| `--n-cpu-moe-draft, -ncmoed N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model<br/>(env: LLAMA_ARG_N_CPU_MOE_DRAFT) |
| `-otd, --override-tensor-draft <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model |
| `-cmoed, --cpu-moe-draft` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>(env: LLAMA_ARG_CPU_MOE_DRAFT) |
| `-ncmoed, --n-cpu-moe-draft N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model<br/>(env: LLAMA_ARG_N_CPU_MOE_DRAFT) |
| `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
| `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
@ -184,7 +184,7 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--webui-config-file PATH` | JSON file that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG_FILE) |
| `--webui, --no-webui` | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_WEBUI) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
| `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
| `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
| `--api-key-file FNAME` | path to file containing API keys (default: none) |
| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
@ -212,7 +212,7 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
| `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) |
| `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) |
| `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16)<br/>(env: LLAMA_ARG_DRAFT_MAX) |
| `--draft, --draft-n, --draft-max N` | number of tokens to draft for speculative decoding (default: 16)<br/>(env: LLAMA_ARG_DRAFT_MAX) |
| `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_DRAFT_MIN) |
| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.8)<br/>(env: LLAMA_ARG_DRAFT_P_MIN) |
| `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_CTX_SIZE_DRAFT) |
@ -1443,6 +1443,12 @@ Example:
```ini
version = 1
; (Optional) This section provides global settings shared across all presets.
; If the same key is defined in a specific preset, it will override the value in this global section.
[*]
c = 8192
n-gpu-layer = 8
; If the key corresponds to an existing model on the server,
; this will be used as the default config for that model
[ggml-org/MY-MODEL-GGUF:Q8_0]
@ -1462,12 +1468,20 @@ model-draft = ./my-models/draft.gguf
model-draft = /Users/abc/my-models/draft.gguf
; If the key does NOT correspond to an existing model,
; you need to specify at least the model path
; you need to specify at least the model path or HF repo
[custom_model]
model = /Users/abc/my-awesome-model-Q4_K_M.gguf
```
Note: some arguments are controlled by router (e.g., host, port, API key, HF repo, model alias). They will be removed or overwritten upload loading.
Note: some arguments are controlled by router (e.g., host, port, API key, HF repo, model alias). They will be removed or overwritten upon loading.
The precedence rule for preset options is as follows:
1. **Command-line arguments** passed to `llama-server` (highest priority)
2. **Model-specific options** defined in the preset file (e.g. `[ggml-org/MY-MODEL...]`)
3. **Global options** defined in the preset file (`[*]`)
We also offer additional options that are exclusive to presets (these aren't treated as command-line arguments):
- `load-on-startup` (boolean): Controls whether the model loads automatically when the server starts
### Routing requests
@ -1607,6 +1621,16 @@ Example of an error:
}
```
## Sleeping on Idle
The server supports an automatic sleep mode that activates after a specified period of inactivity (no incoming tasks). This feature, introduced in [PR #18228](https://github.com/ggml-org/llama.cpp/pull/18228), can be enabled using the `--sleep-idle-seconds` command-line argument. It works seamlessly in both single-model and multi-model configurations.
When the server enters sleep mode, the model and its associated memory (including the KV cache) are unloaded from RAM to conserve resources. Any new incoming task will automatically trigger the model to reload.
Note that the following endpoints are exempt from being considered as incoming tasks. They do not trigger model reloading and do not reset the idle timer:
- `GET /health`
- `GET /props`
## More examples
### Interactive mode

Binary file not shown.

View File

@ -544,7 +544,9 @@ struct server_context_impl {
server_metrics metrics;
json webui_settings = json::object();
// cached responses for HTTP API (read-only from HTTP threads)
json json_server_props = json::object();
json json_server_model_meta = json::object();
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
@ -554,8 +556,23 @@ struct server_context_impl {
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;
bool sleeping = false;
~server_context_impl() {
if (!sleeping) {
// destroy() is already called when entering sleeping state
// we don't call it again here to avoid double free
destroy();
}
}
void destroy() {
llama_init.reset();
ctx = nullptr;
model = nullptr;
mtmd_free(mctx);
mctx = nullptr;
// Clear any sampling context
for (server_slot & slot : slots) {
@ -571,22 +588,29 @@ struct server_context_impl {
llama_batch_free(batch);
}
void handle_sleeping_state(bool new_state) {
GGML_ASSERT(sleeping != new_state);
if (new_state) {
SRV_INF("%s", "server is entering sleeping state\n");
destroy();
} else {
SRV_INF("%s", "server is exiting sleeping state\n");
if (!load_model(params_base)) {
GGML_ABORT("failed to reload model after sleeping");
}
}
sleeping = new_state;
}
// load the model and initialize llama_context
// this may also be called to resume from sleeping state
bool load_model(const common_params & params) {
bool is_resume = sleeping;
SRV_INF("loading model '%s'\n", params.model.path.c_str());
params_base = params;
webui_settings = json::object();
if (!params_base.webui_config_json.empty()) {
try {
webui_settings = json::parse(params_base.webui_config_json);
} catch (const std::exception & e) {
SRV_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
return false;
}
}
llama_init = common_init_from_params(params_base);
model = llama_init->model();
@ -654,7 +678,9 @@ struct server_context_impl {
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_helper_log_set(common_log_default_callback, nullptr);
if (!is_resume) {
mtmd_helper_log_set(common_log_default_callback, nullptr);
}
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
@ -699,19 +725,6 @@ struct server_context_impl {
}
}
return true;
}
// initialize slots and server-related data
void init() {
// wiring up server queues
queue_tasks.on_new_task([this](server_task && task) {
process_single_task(std::move(task));
});
queue_tasks.on_update_slots([this]() {
update_slots();
});
// Necessary similarity of prompt for slot selection
slot_prompt_similarity = params_base.slot_prompt_similarity;
@ -726,6 +739,7 @@ struct server_context_impl {
n_ctx_slot = n_ctx_train;
}
slots.clear();
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
@ -742,13 +756,13 @@ struct server_context_impl {
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return;
return false;
}
slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return;
return false;
}
for (auto & pair : params_base.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
@ -782,8 +796,6 @@ struct server_context_impl {
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
metrics.init();
if (params_base.cache_ram_mib != 0) {
if (params_base.cache_ram_mib < 0) {
SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
@ -832,6 +844,103 @@ struct server_context_impl {
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
common_chat_templates_source(chat_templates.get()),
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
if (!is_resume) {
return init();
}
return true;
}
// unlike load_model(), this is only called once during initialization
bool init() {
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(model != nullptr);
GGML_ASSERT(!sleeping);
// wiring up server queues
queue_tasks.on_new_task([this](server_task && task) {
process_single_task(std::move(task));
});
queue_tasks.on_update_slots([this]() {
update_slots();
});
queue_tasks.on_sleeping_state([this](bool sleeping) {
handle_sleeping_state(sleeping);
});
metrics.init();
if (!populate_json_responses()) {
SRV_ERR("%s", "failed to populate JSON responses\n");
return false;
}
return true;
}
bool populate_json_responses() {
// populate webui settings
json json_webui_settings = json::object();
{
if (!params_base.webui_config_json.empty()) {
try {
json_webui_settings = json::parse(params_base.webui_config_json);
} catch (const std::exception & e) {
SRV_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
return false;
}
}
}
// populate server properties
{
task_params params;
params.sampling = params_base.sampling;
json default_generation_settings_for_props = json {
{"params", params.to_json(true)},
{"n_ctx", get_slot_n_ctx()},
};
json_server_props = {
{ "default_generation_settings", default_generation_settings_for_props },
{ "total_slots", params_base.n_parallel },
{ "model_alias", model_name },
{ "model_path", params_base.model.path },
{ "modalities", json {
{"vision", oai_parser_opt.allow_image},
{"audio", oai_parser_opt.allow_audio},
} },
{ "endpoint_slots", params_base.endpoint_slots },
{ "endpoint_props", params_base.endpoint_props },
{ "endpoint_metrics", params_base.endpoint_metrics },
{ "webui", params_base.webui },
{ "webui_settings", json_webui_settings },
{ "chat_template", common_chat_templates_source(chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx, llama_vocab_bos(vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx, llama_vocab_eos(vocab), /* special= */ true)},
{ "build_info", build_info },
};
if (params_base.use_jinja) {
if (auto tool_use_src = common_chat_templates_source(chat_templates.get(), "tool_use")) {
json_server_props["chat_template_tool_use"] = tool_use_src;
}
}
}
// populate model metadata
{
json_server_model_meta = {
{"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_vocab_n_tokens (vocab)},
{"n_ctx_train", llama_model_n_ctx_train(model)},
{"n_embd", llama_model_n_embd (model)},
{"n_params", llama_model_n_params (model)},
{"size", llama_model_size (model)},
};
}
return true;
}
server_slot * get_slot_by_id(int id) {
@ -1993,19 +2102,33 @@ struct server_context_impl {
if (!slot.can_split()) {
if (slot.task->n_tokens() > n_ubatch) {
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
send_error(slot,
string_format(
"input (%d tokens) is too large to process. increase the physical batch "
"size (current batch size: %d)",
slot.task->n_tokens(), n_ubatch),
ERROR_TYPE_SERVER);
slot.release();
continue;
}
if (slot.task->n_tokens() > slot.n_ctx) {
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
send_error(
slot,
string_format(
"input (%d tokens) is larger than the max context size (%d tokens). skipping",
slot.task->n_tokens(), slot.n_ctx),
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}
} else {
if (slot.task->n_tokens() >= slot.n_ctx) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
send_error(slot,
string_format("request (%d tokens) exceeds the available context size (%d "
"tokens), try increasing it",
slot.task->n_tokens(), slot.n_ctx),
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}
@ -2633,24 +2756,13 @@ struct server_context_impl {
}
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens());
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
}
}
SRV_DBG("%s", "run slots completed\n");
}
json model_meta() const {
return json {
{"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_vocab_n_tokens (vocab)},
{"n_ctx_train", llama_model_n_ctx_train(model)},
{"n_embd", llama_model_n_embd (model)},
{"n_params", llama_model_n_params (model)},
{"size", llama_model_size (model)},
};
}
int get_slot_n_ctx() {
return slots.back().n_ctx;
}
@ -2667,16 +2779,13 @@ struct server_context_impl {
server_context::server_context() : impl(new server_context_impl()) {}
server_context::~server_context() = default;
void server_context::init() {
impl->init();
}
bool server_context::load_model(const common_params & params) {
return impl->load_model(params);
}
void server_context::start_loop() {
impl->queue_tasks.start_loop();
auto & params = impl->params_base;
impl->queue_tasks.start_loop(params.sleep_idle_seconds * 1000);
}
void server_context::terminate() {
@ -2703,10 +2812,17 @@ server_context_info server_context::get_info() const {
// generator-like API for HTTP response generation
// may have bypass_sleep = true if the task does not use ctx_server
struct server_res_generator : server_http_res {
server_response_reader rd;
server_res_generator(server_context_impl & ctx_server)
: rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
server_res_generator(server_context_impl & ctx_server, bool bypass_sleep = false)
: rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {
// fast path in case sleeping is disabled
bypass_sleep |= ctx_server.params_base.sleep_idle_seconds < 0;
if (!bypass_sleep) {
ctx_server.queue_tasks.wait_until_no_sleep();
}
}
void ok(const json & response_data) {
status = 200;
data = safe_json_to_str(response_data);
@ -2724,6 +2840,7 @@ struct server_res_generator : server_http_res {
//
static std::unique_ptr<server_res_generator> handle_completions_impl(
std::unique_ptr<server_res_generator> && res_ptr,
server_context_impl & ctx_server,
server_task_type type,
const json & data,
@ -2732,7 +2849,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task_response_type res_type) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
auto res = std::make_unique<server_res_generator>(ctx_server);
auto res = std::move(res_ptr);
auto completion_id = gen_chatcmplid();
auto & rd = res->rd;
@ -2936,9 +3053,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
}
void server_routes::init_routes() {
// IMPORTANT: all lambda functions must start with std::make_unique<server_res_generator>
// this is to ensure that the server_res_generator can handle sleeping case correctly
this->get_health = [this](const server_http_req &) {
// error and loading states are handled by middleware
auto res = std::make_unique<server_res_generator>(ctx_server);
auto res = std::make_unique<server_res_generator>(ctx_server, true);
res->ok({{"status", "ok"}});
return res;
};
@ -3120,46 +3240,10 @@ void server_routes::init_routes() {
};
this->get_props = [this](const server_http_req &) {
auto res = std::make_unique<server_res_generator>(ctx_server);
json default_generation_settings_for_props;
{
task_params params;
params.sampling = ctx_server.params_base.sampling;
default_generation_settings_for_props = json {
{"params", params.to_json(true)},
{"n_ctx", ctx_server.get_slot_n_ctx()},
};
}
json data = {
{ "default_generation_settings", default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_alias", ctx_server.model_name },
{ "model_path", ctx_server.params_base.model.path },
{ "modalities", json {
{"vision", ctx_server.oai_parser_opt.allow_image},
{"audio", ctx_server.oai_parser_opt.allow_audio},
} },
{ "endpoint_slots", params.endpoint_slots },
{ "endpoint_props", params.endpoint_props },
{ "endpoint_metrics", params.endpoint_metrics },
{ "webui", params.webui },
{ "webui_settings", ctx_server.webui_settings },
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
{ "build_info", build_info },
};
if (ctx_server.params_base.use_jinja) {
if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
data["chat_template_tool_use"] = tool_use_src;
}
}
res->ok(data);
auto res = std::make_unique<server_res_generator>(ctx_server, true);
auto props = ctx_server.json_server_props;
props["is_sleeping"] = ctx_server.queue_tasks.is_sleeping();
res->ok(props);
return res;
};
@ -3277,6 +3361,7 @@ void server_routes::init_routes() {
std::vector<raw_buffer> files; // dummy
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_INFILL,
data,
@ -3286,9 +3371,11 @@ void server_routes::init_routes() {
};
this->post_completions = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files; // dummy
const json body = json::parse(req.body);
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_COMPLETION,
body,
@ -3298,9 +3385,11 @@ void server_routes::init_routes() {
};
this->post_completions_oai = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files; // dummy
const json body = json::parse(req.body);
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_COMPLETION,
body,
@ -3310,6 +3399,7 @@ void server_routes::init_routes() {
};
this->post_chat_completions = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files;
json body = json::parse(req.body);
json body_parsed = oaicompat_chat_params_parse(
@ -3317,6 +3407,7 @@ void server_routes::init_routes() {
ctx_server.oai_parser_opt,
files);
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_COMPLETION,
body_parsed,
@ -3326,6 +3417,7 @@ void server_routes::init_routes() {
};
this->post_anthropic_messages = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files;
json body = convert_anthropic_to_oai(json::parse(req.body));
json body_parsed = oaicompat_chat_params_parse(
@ -3333,6 +3425,7 @@ void server_routes::init_routes() {
ctx_server.oai_parser_opt,
files);
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_COMPLETION,
body_parsed,
@ -3370,11 +3463,13 @@ void server_routes::init_routes() {
return res;
};
// TODO: this endpoint is unsafe to access during model reloading (i.e. wake up from sleeping)
// how to make it work even during load_model()?
this->get_models = [this](const server_http_req &) {
auto res = std::make_unique<server_res_generator>(ctx_server);
json model_meta = nullptr;
if (is_ready()) {
model_meta = ctx_server.model_meta();
model_meta = ctx_server.json_server_model_meta;
}
bool has_mtmd = ctx_server.mctx != nullptr;
json models = {

View File

@ -22,9 +22,6 @@ struct server_context {
server_context();
~server_context();
// initialize slots and server-related data
void init();
// load the model and initialize llama_context
// returns true on success
bool load_model(const common_params & params);
@ -35,7 +32,7 @@ struct server_context {
// terminate main loop (will unblock start_loop)
void terminate();
// get the underlaying llama_context
// get the underlaying llama_context, can return nullptr if sleeping
llama_context * get_llama_context() const;
// get a new response reader, used by CLI application

View File

@ -82,154 +82,30 @@ static std::filesystem::path get_server_exec_path() {
#endif
}
struct local_model {
std::string name;
std::string path;
std::string path_mmproj;
};
static std::vector<local_model> list_local_models(const std::string & dir) {
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str()));
}
std::vector<local_model> models;
auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
auto files = fs_list(subdir_path, false);
common_file_info model_file;
common_file_info first_shard_file;
common_file_info mmproj_file;
for (const auto & file : files) {
if (string_ends_with(file.name, ".gguf")) {
if (file.name.find("mmproj") != std::string::npos) {
mmproj_file = file;
} else if (file.name.find("-00001-of-") != std::string::npos) {
first_shard_file = file;
} else {
model_file = file;
}
}
}
// single file model
local_model model{
/* name */ name,
/* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
/* path_mmproj */ mmproj_file.path // can be empty
};
if (!model.path.empty()) {
models.push_back(model);
}
};
auto files = fs_list(dir, true);
for (const auto & file : files) {
if (file.is_dir) {
scan_subdir(file.path, file.name);
} else if (string_ends_with(file.name, ".gguf")) {
// single file model
std::string name = file.name;
string_replace_all(name, ".gguf", "");
local_model model{
/* name */ name,
/* path */ file.path,
/* path_mmproj */ ""
};
models.push_back(model);
}
}
return models;
}
//
// server_presets
//
server_presets::server_presets(int argc, char ** argv, common_params & base_params, const std::string & presets_path)
: ctx_params(common_params_parser_init(base_params, LLAMA_EXAMPLE_SERVER)) {
if (!presets_path.empty()) {
presets = common_presets_load(presets_path, ctx_params);
SRV_INF("Loaded %zu presets from %s\n", presets.size(), presets_path.c_str());
}
// populate reserved args (will be appended by the router)
for (auto & opt : ctx_params.options) {
if (opt.env == nullptr) {
continue;
}
std::string env = opt.env;
if (env == "LLAMA_ARG_PORT" ||
env == "LLAMA_ARG_HOST" ||
env == "LLAMA_ARG_ALIAS" ||
env == "LLAMA_ARG_API_KEY" ||
env == "LLAMA_ARG_MODELS_DIR" ||
env == "LLAMA_ARG_MODELS_MAX" ||
env == "LLAMA_ARG_MODELS_PRESET" ||
env == "LLAMA_ARG_MODEL" ||
env == "LLAMA_ARG_MMPROJ" ||
env == "LLAMA_ARG_HF_REPO" ||
env == "LLAMA_ARG_NO_MODELS_AUTOLOAD" ||
env == "LLAMA_ARG_SSL_KEY_FILE" ||
env == "LLAMA_ARG_SSL_CERT_FILE") {
control_args[env] = opt;
}
}
// read base args from router's argv
common_params_to_map(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
// remove any router-controlled args from base_args
for (const auto & cargs : control_args) {
auto it = base_args.find(cargs.second);
if (it != base_args.end()) {
base_args.erase(it);
}
static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
preset.unset_option("LLAMA_API_KEY");
preset.unset_option("LLAMA_ARG_MODELS_DIR");
preset.unset_option("LLAMA_ARG_MODELS_MAX");
preset.unset_option("LLAMA_ARG_MODELS_PRESET");
preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
if (unset_model_args) {
preset.unset_option("LLAMA_ARG_MODEL");
preset.unset_option("LLAMA_ARG_MMPROJ");
preset.unset_option("LLAMA_ARG_HF_REPO");
}
}
common_preset server_presets::get_preset(const std::string & name) {
auto it = presets.find(name);
if (it != presets.end()) {
return it->second;
}
return common_preset();
}
void server_presets::render_args(server_model_meta & meta) {
common_preset preset = meta.preset; // copy
// merging 3 kinds of args:
// 1. model-specific args (from preset)
// force removing control args if any
for (auto & cargs : control_args) {
if (preset.options.find(cargs.second) != preset.options.end()) {
SRV_WRN("Preset '%s' contains reserved arg '%s', removing it\n", preset.name.c_str(), cargs.second.args[0]);
preset.options.erase(cargs.second);
}
}
// 2. base args (from router)
// inherit from base args
for (const auto & [arg, value] : base_args) {
preset.options[arg] = value;
}
// 3. control args (from router)
// set control values
preset.options[control_args["LLAMA_ARG_HOST"]] = CHILD_ADDR;
preset.options[control_args["LLAMA_ARG_PORT"]] = std::to_string(meta.port);
preset.options[control_args["LLAMA_ARG_ALIAS"]] = meta.name;
if (meta.in_cache) {
preset.options[control_args["LLAMA_ARG_HF_REPO"]] = meta.name;
} else {
preset.options[control_args["LLAMA_ARG_MODEL"]] = meta.path;
if (!meta.path_mmproj.empty()) {
preset.options[control_args["LLAMA_ARG_MMPROJ"]] = meta.path_mmproj;
}
}
// disable SSL for child processes (HTTPS already handled by router)
preset.options[control_args["LLAMA_ARG_SSL_KEY_FILE"]] = "";
preset.options[control_args["LLAMA_ARG_SSL_CERT_FILE"]] = "";
meta.args = preset.to_args();
// add back the binary path at the front
meta.args.insert(meta.args.begin(), get_server_exec_path().string());
void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
// update params
unset_reserved_args(preset, false);
preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR);
preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port));
preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
// TODO: maybe validate preset before rendering ?
// render args
args = preset.to_args(bin_path);
}
//
@ -240,20 +116,22 @@ server_models::server_models(
const common_params & params,
int argc,
char ** argv,
char ** envp) : base_params(params), presets(argc, argv, base_params, params.models_preset) {
for (int i = 0; i < argc; i++) {
base_args.push_back(std::string(argv[i]));
}
char ** envp)
: ctx_preset(LLAMA_EXAMPLE_SERVER),
base_params(params),
base_preset(ctx_preset.load_from_args(argc, argv)) {
for (char ** env = envp; *env != nullptr; env++) {
base_env.push_back(std::string(*env));
}
GGML_ASSERT(!base_args.empty());
// clean up base preset
unset_reserved_args(base_preset, true);
// set binary path
try {
base_args[0] = get_server_exec_path().string();
bin_path = get_server_exec_path().string();
} catch (const std::exception & e) {
bin_path = argv[0];
LOG_WRN("failed to get server executable path: %s\n", e.what());
LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str());
LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
}
load_models();
}
@ -262,7 +140,7 @@ void server_models::add_model(server_model_meta && meta) {
if (mapping.find(meta.name) != mapping.end()) {
throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
}
presets.render_args(meta); // populate meta.args
meta.update_args(ctx_preset, bin_path); // render args
std::string name = meta.name;
mapping[name] = instance_t{
/* subproc */ std::make_shared<subprocess_s>(),
@ -271,86 +149,62 @@ void server_models::add_model(server_model_meta && meta) {
};
}
static std::vector<local_model> list_custom_path_models(server_presets & presets) {
// detect any custom-path models in presets
std::vector<local_model> custom_models;
for (auto & [model_name, preset] : presets.presets) {
local_model model;
model.name = model_name;
std::vector<common_arg> to_erase;
for (auto & [arg, value] : preset.options) {
std::string env(arg.env ? arg.env : "");
if (env == "LLAMA_ARG_MODEL") {
model.path = value;
to_erase.push_back(arg);
}
if (env == "LLAMA_ARG_MMPROJ") {
model.path_mmproj = value;
to_erase.push_back(arg);
}
}
for (auto & arg : to_erase) {
preset.options.erase(arg);
}
if (!model.name.empty() && !model.path.empty()) {
custom_models.push_back(model);
}
}
return custom_models;
}
// TODO: allow refreshing cached model list
void server_models::load_models() {
// loading models from 3 sources:
// 1. cached models
auto cached_models = common_list_cached_models();
for (const auto & model : cached_models) {
server_model_meta meta{
/* preset */ presets.get_preset(model.to_string()),
/* name */ model.to_string(),
/* path */ model.manifest_path,
/* path_mmproj */ "", // auto-detected when loading
/* in_cache */ true,
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* exit_code */ 0
};
add_model(std::move(meta));
}
// 2. local models specificed via --models-dir
common_presets cached_models = ctx_preset.load_from_cache();
SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
// 2. local models from --models-dir
common_presets local_models;
if (!base_params.models_dir.empty()) {
auto local_models = list_local_models(base_params.models_dir);
for (const auto & model : local_models) {
if (mapping.find(model.name) != mapping.end()) {
// already exists in cached models, skip
continue;
}
server_model_meta meta{
/* preset */ presets.get_preset(model.name),
/* name */ model.name,
/* path */ model.path,
/* path_mmproj */ model.path_mmproj,
/* in_cache */ false,
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* exit_code */ 0
};
add_model(std::move(meta));
local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
}
// 3. custom-path models from presets
common_preset global = {};
common_presets custom_presets = {};
if (!base_params.models_preset.empty()) {
custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
}
// cascade, apply global preset first
cached_models = ctx_preset.cascade(global, cached_models);
local_models = ctx_preset.cascade(global, local_models);
custom_presets = ctx_preset.cascade(global, custom_presets);
// note: if a model exists in both cached and local, local takes precedence
common_presets final_presets;
for (const auto & [name, preset] : cached_models) {
final_presets[name] = preset;
}
for (const auto & [name, preset] : local_models) {
final_presets[name] = preset;
}
// process custom presets from INI
for (const auto & [name, custom] : custom_presets) {
if (final_presets.find(name) != final_presets.end()) {
// apply custom config if exists
common_preset & target = final_presets[name];
target.merge(custom);
} else {
// otherwise add directly
final_presets[name] = custom;
}
}
// 3. custom-path models specified in presets
auto custom_models = list_custom_path_models(presets);
for (const auto & model : custom_models) {
// server base preset from CLI args take highest precedence
for (auto & [name, preset] : final_presets) {
preset.merge(base_preset);
}
// convert presets to server_model_meta and add to mapping
for (const auto & preset : final_presets) {
server_model_meta meta{
/* preset */ presets.get_preset(model.name),
/* name */ model.name,
/* path */ model.path,
/* path_mmproj */ model.path_mmproj,
/* in_cache */ false,
/* preset */ preset.second,
/* name */ preset.first,
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
@ -359,10 +213,38 @@ void server_models::load_models() {
};
add_model(std::move(meta));
}
// log available models
SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
{
std::unordered_set<std::string> custom_names;
for (const auto & [name, preset] : custom_presets) {
custom_names.insert(name);
}
SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
for (const auto & [name, inst] : mapping) {
bool has_custom = custom_names.find(name) != custom_names.end();
SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str());
}
}
// load any autoload models
std::vector<std::string> models_to_load;
for (const auto & [name, inst] : mapping) {
SRV_INF(" %c %s\n", inst.meta.preset.name.empty() ? ' ' : '*', name.c_str());
std::string val;
if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) {
models_to_load.push_back(name);
}
}
if ((int)models_to_load.size() > base_params.models_max) {
throw std::runtime_error(string_format(
"number of models to load on startup (%zu) exceeds models_max (%d)",
models_to_load.size(),
base_params.models_max
));
}
for (const auto & name : models_to_load) {
SRV_INF("(startup) loading model %s\n", name.c_str());
load(name);
}
}
@ -526,7 +408,7 @@ void server_models::load(const std::string & name) {
{
SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
presets.render_args(inst.meta); // update meta.args
inst.meta.update_args(ctx_preset, bin_path); // render args
std::vector<std::string> child_args = inst.meta.args; // copy
std::vector<std::string> child_env = base_env; // copy
@ -877,7 +759,12 @@ void server_models_routes::init_routes() {
{"args", meta.args},
};
if (!meta.preset.name.empty()) {
status["preset"] = meta.preset.to_ini();
common_preset preset_copy = meta.preset;
unset_reserved_args(preset_copy, false);
preset_copy.unset_option("LLAMA_ARG_HOST");
preset_copy.unset_option("LLAMA_ARG_PORT");
preset_copy.unset_option("LLAMA_ARG_ALIAS");
status["preset"] = preset_copy.to_ini();
}
if (meta.is_failed()) {
status["exit_code"] = meta.exit_code;
@ -888,8 +775,6 @@ void server_models_routes::init_routes() {
{"object", "model"}, // for OAI-compat
{"owned_by", "llamacpp"}, // for OAI-compat
{"created", t}, // for OAI-compat
{"in_cache", meta.in_cache},
{"path", meta.path},
{"status", status},
// TODO: add other fields, may require reading GGUF metadata
});

View File

@ -51,9 +51,6 @@ static std::string server_model_status_to_string(server_model_status status) {
struct server_model_meta {
common_preset preset;
std::string name;
std::string path;
std::string path_mmproj; // only available if in_cache=false
bool in_cache = false; // if true, use -hf; use -m otherwise
int port = 0;
server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
int64_t last_used = 0; // for LRU unloading
@ -67,19 +64,8 @@ struct server_model_meta {
bool is_failed() const {
return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
}
};
// the server_presets struct holds the presets read from presets.ini
// as well as base args from the router server
struct server_presets {
common_presets presets;
common_params_context ctx_params;
std::map<common_arg, std::string> base_args;
std::map<std::string, common_arg> control_args; // args reserved for server control
server_presets(int argc, char ** argv, common_params & base_params, const std::string & models_dir);
common_preset get_preset(const std::string & name);
void render_args(server_model_meta & meta);
void update_args(common_preset_context & ctx_presets, std::string bin_path);
};
struct subprocess_s;
@ -97,11 +83,12 @@ private:
std::condition_variable cv;
std::map<std::string, instance_t> mapping;
common_params base_params;
std::vector<std::string> base_args;
std::vector<std::string> base_env;
common_preset_context ctx_preset;
server_presets presets;
common_params base_params;
std::string bin_path;
std::vector<std::string> base_env;
common_preset base_preset; // base preset from llama-server CLI args
void update_meta(const std::string & name, const server_model_meta & meta);
@ -116,27 +103,29 @@ public:
void load_models();
// check if a model instance exists
// check if a model instance exists (thread-safe)
bool has_model(const std::string & name);
// return a copy of model metadata
// return a copy of model metadata (thread-safe)
std::optional<server_model_meta> get_meta(const std::string & name);
// return a copy of all model metadata
// return a copy of all model metadata (thread-safe)
std::vector<server_model_meta> get_all_meta();
// load and unload model instances
// these functions are thread-safe
void load(const std::string & name);
void unload(const std::string & name);
void unload_all();
// update the status of a model instance
// update the status of a model instance (thread-safe)
void update_status(const std::string & name, server_model_status status);
// wait until the model instance is fully loaded
// wait until the model instance is fully loaded (thread-safe)
// return when the model is loaded or failed to load
void wait_until_loaded(const std::string & name);
// load the model if not loaded, otherwise do nothing
// load the model if not loaded, otherwise do nothing (thread-safe)
// return false if model is already loaded; return true otherwise (meta may need to be refreshed)
bool ensure_model_loaded(const std::string & name);

View File

@ -33,6 +33,7 @@ int server_queue::post(server_task && task, bool front) {
} else {
queue_tasks.push_back(std::move(task));
}
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
return task_id;
}
@ -54,6 +55,7 @@ int server_queue::post(std::vector<server_task> && tasks, bool front) {
queue_tasks.push_back(std::move(task));
}
}
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
return 0;
}
@ -62,6 +64,7 @@ void server_queue::defer(server_task && task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
QUE_DBG("defer task, id = %d\n", task.id);
queue_tasks_deferred.push_back(std::move(task));
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
}
@ -71,31 +74,52 @@ int server_queue::get_new_id() {
return new_id;
}
void server_queue::on_new_task(std::function<void(server_task &&)> callback) {
callback_new_task = std::move(callback);
}
void server_queue::on_update_slots(std::function<void(void)> callback) {
callback_update_slots = std::move(callback);
}
void server_queue::pop_deferred_task() {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!queue_tasks_deferred.empty()) {
queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
queue_tasks_deferred.pop_front();
}
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
}
void server_queue::wait_until_no_sleep() {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!sleeping) {
return;
} else {
if (!req_stop_sleeping) {
QUE_DBG("%s", "requesting to stop sleeping\n");
req_stop_sleeping = true;
condition_tasks.notify_one(); // only main thread is waiting on this
}
QUE_DBG("%s", "waiting until no sleep\n");
condition_tasks.wait(lock, [&]{
return !sleeping;
});
}
}
void server_queue::terminate() {
std::unique_lock<std::mutex> lock(mutex_tasks);
running = false;
condition_tasks.notify_all();
}
void server_queue::start_loop() {
void server_queue::start_loop(int64_t idle_sleep_ms) {
running = true;
time_last_task = ggml_time_ms();
constexpr auto max_wait_time = std::chrono::seconds(1);
auto should_sleep = [&]() -> bool {
// caller must hold mutex_tasks
if (idle_sleep_ms < 0) {
return false;
}
int64_t now = ggml_time_ms();
return (now - time_last_task) >= idle_sleep_ms;
};
while (true) {
QUE_DBG("%s", "processing new tasks\n");
@ -117,23 +141,53 @@ void server_queue::start_loop() {
QUE_DBG("processing task, id = %d\n", task.id);
callback_new_task(std::move(task));
}
// all tasks in the current loop is processed, slots data is now ready
QUE_DBG("%s", "update slots\n");
// this will run the main inference process for all slots
callback_update_slots();
{
// update_slots() may take a while to finish, we need to make sure it's not counted as idle
std::unique_lock<std::mutex> lock(mutex_tasks);
time_last_task = ggml_time_ms();
}
QUE_DBG("%s", "waiting for new tasks\n");
{
while (true) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
if (!running || !queue_tasks.empty()) {
break; // go back to process new tasks or terminate
}
if (queue_tasks.empty()) {
// no tasks, check for sleeping state
if (should_sleep()) {
QUE_INF("%s", "entering sleeping state\n");
sleeping = true;
callback_sleeping_state(true);
req_stop_sleeping = false;
// wait until we are requested to exit sleeping state
condition_tasks.wait(lock, [&]{
return (!running || req_stop_sleeping);
});
if (!running) { // may changed during sleep
break; // terminate
}
QUE_INF("%s", "exiting sleeping state\n");
req_stop_sleeping = false;
callback_sleeping_state(false);
sleeping = false;
time_last_task = ggml_time_ms();
condition_tasks.notify_all(); // notify wait_until_no_sleep()
break; // process new tasks
} else {
// wait for new tasks or timeout for checking sleeping condition
bool res = condition_tasks.wait_for(lock, max_wait_time, [&]{
return (!queue_tasks.empty() || !running);
});
if (res) {
break; // new task arrived or terminate
}
// otherwise, loop again to check sleeping condition
}
}
}

View File

@ -12,7 +12,10 @@
struct server_queue {
private:
int id = 0;
bool running;
bool running = false;
bool sleeping = false;
bool req_stop_sleeping = false;
int64_t time_last_task = 0;
// queues
std::deque<server_task> queue_tasks;
@ -24,6 +27,7 @@ private:
// callback functions
std::function<void(server_task &&)> callback_new_task;
std::function<void(void)> callback_update_slots;
std::function<void(bool)> callback_sleeping_state;
public:
// Add a new task to the end of the queue
@ -38,15 +42,18 @@ public:
// Get the next id for creating a new task
int get_new_id();
// Register function to process a new task
void on_new_task(std::function<void(server_task &&)> callback);
// Register the function to be called when all slots data is ready to be processed
void on_update_slots(std::function<void(void)> callback);
// Call when the state of one slot is changed, it will move one task from deferred to main queue
void pop_deferred_task();
// if sleeping, request exiting sleep state and wait until it is done
// returns immediately if not sleeping
void wait_until_no_sleep();
bool is_sleeping() {
std::unique_lock<std::mutex> lock(mutex_tasks);
return sleeping;
}
// end the start_loop routine
void terminate();
@ -56,8 +63,15 @@ public:
* - Process the task (i.e. maybe copy data into slot)
* - Check if multitask is finished
* - Update all slots
*
* Sleeping procedure (disabled if idle_sleep_ms < 0):
* - If there is no task after idle_sleep_ms, enter sleeping state
* - Call callback_sleeping_state(true)
* - Wait until req_stop_sleeping is set to true
* - Call callback_sleeping_state(false)
* - Exit sleeping state
*/
void start_loop();
void start_loop(int64_t idle_sleep_ms = -1);
// for metrics
size_t queue_tasks_deferred_size() {
@ -65,6 +79,27 @@ public:
return queue_tasks_deferred.size();
}
//
// Functions below are not thread-safe, must only be used before start_loop() is called
//
// Register function to process a new task
void on_new_task(std::function<void(server_task &&)> callback) {
callback_new_task = std::move(callback);
}
// Register the function to be called when all slots data is ready to be processed
void on_update_slots(std::function<void(void)> callback) {
callback_update_slots = std::move(callback);
}
// Register callback for sleeping state change
// note: when entering sleeping state, the callback is called AFTER sleeping is set to true
// when leaving sleeping state, the callback is called BEFORE sleeping is set to false
void on_sleeping_state(std::function<void(bool)> callback) {
callback_sleeping_state = std::move(callback);
}
private:
void cleanup_pending_task(int id_target);
};

View File

@ -252,7 +252,6 @@ int main(int argc, char ** argv, char ** envp) {
return 1;
}
ctx_server.init();
ctx_http.is_ready.store(true);
LOG_INF("%s: model loaded\n", __func__);
@ -309,7 +308,11 @@ int main(int argc, char ** argv, char ** envp) {
if (monitor_thread.joinable()) {
monitor_thread.join();
}
llama_memory_breakdown_print(ctx_server.get_llama_context());
auto * ll_ctx = ctx_server.get_llama_context();
if (ll_ctx != nullptr) {
llama_memory_breakdown_print(ll_ctx);
}
}
return 0;

View File

@ -0,0 +1,39 @@
import pytest
import time
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_server_sleep():
global server
server.sleep_idle_seconds = 1
server.start()
# wait a bit so that server can go to sleep
time.sleep(2)
# make sure these endpoints are still responsive after sleep
res = server.make_request("GET", "/health")
assert res.status_code == 200
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["is_sleeping"] == True
# make a generation request to wake up the server
res = server.make_request("POST", "/completion", data={
"n_predict": 1,
"prompt": "Hello",
})
assert res.status_code == 200
# it should no longer be sleeping
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["is_sleeping"] == False

View File

@ -100,6 +100,7 @@ class ServerProcess:
server_path: str | None = None
mmproj_url: str | None = None
media_path: str | None = None
sleep_idle_seconds: int | None = None
# session variables
process: subprocess.Popen | None = None
@ -230,6 +231,8 @@ class ServerProcess:
server_args.extend(["--mmproj-url", self.mmproj_url])
if self.media_path:
server_args.extend(["--media-path", self.media_path])
if self.sleep_idle_seconds is not None:
server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")

View File

@ -11,6 +11,8 @@ flowchart TB
C_Screen["ChatScreen"]
C_Form["ChatForm"]
C_Messages["ChatMessages"]
C_Message["ChatMessage"]
C_MessageEditForm["ChatMessageEditForm"]
C_ModelsSelector["ModelsSelector"]
C_Settings["ChatSettings"]
end
@ -54,7 +56,9 @@ flowchart TB
%% Component hierarchy
C_Screen --> C_Form & C_Messages & C_Settings
C_Form & C_Messages --> C_ModelsSelector
C_Messages --> C_Message
C_Message --> C_MessageEditForm
C_Form & C_MessageEditForm --> C_ModelsSelector
%% Components → Hooks → Stores
C_Form & C_Messages --> H1 & H2
@ -93,7 +97,7 @@ flowchart TB
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
class R1,R2,RL routeStyle
class C_Sidebar,C_Screen,C_Form,C_Messages,C_ModelsSelector,C_Settings componentStyle
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageEditForm,C_ModelsSelector,C_Settings componentStyle
class H1,H2 hookStyle
class S1,S2,S3,S4,S5 storeStyle
class SV1,SV2,SV3,SV4,SV5 serviceStyle

View File

@ -16,6 +16,8 @@ end
C_Form["ChatForm"]
C_Messages["ChatMessages"]
C_Message["ChatMessage"]
C_MessageUser["ChatMessageUser"]
C_MessageEditForm["ChatMessageEditForm"]
C_Attach["ChatAttachments"]
C_ModelsSelector["ModelsSelector"]
C_Settings["ChatSettings"]
@ -38,7 +40,7 @@ end
S1Error["<b>Error Handling:</b><br/>showErrorDialog()<br/>dismissErrorDialog()<br/>isAbortError()"]
S1Msg["<b>Message Operations:</b><br/>addMessage()<br/>sendMessage()<br/>updateMessage()<br/>deleteMessage()<br/>getDeletionInfo()"]
S1Regen["<b>Regeneration:</b><br/>regenerateMessage()<br/>regenerateMessageWithBranching()<br/>continueAssistantMessage()"]
S1Edit["<b>Editing:</b><br/>editAssistantMessage()<br/>editUserMessagePreserveResponses()<br/>editMessageWithBranching()"]
S1Edit["<b>Editing:</b><br/>editAssistantMessage()<br/>editUserMessagePreserveResponses()<br/>editMessageWithBranching()<br/>clearEditMode()<br/>isEditModeActive()<br/>getAddFilesHandler()<br/>setEditModeActive()"]
S1Utils["<b>Utilities:</b><br/>getApiOptions()<br/>parseTimingData()<br/>getOrCreateAbortController()<br/>getConversationModel()"]
end
subgraph S2["conversationsStore"]
@ -88,6 +90,10 @@ end
RE7["getChatStreaming()"]
RE8["getAllLoadingChats()"]
RE9["getAllStreamingChats()"]
RE9a["isEditModeActive()"]
RE9b["getAddFilesHandler()"]
RE9c["setEditModeActive()"]
RE9d["clearEditMode()"]
end
subgraph ConvExports["conversationsStore"]
RE10["conversations()"]
@ -182,7 +188,10 @@ end
%% Component hierarchy
C_Screen --> C_Form & C_Messages & C_Settings
C_Messages --> C_Message
C_Message --> C_ModelsSelector
C_Message --> C_MessageUser
C_MessageUser --> C_MessageEditForm
C_MessageEditForm --> C_ModelsSelector
C_MessageEditForm --> C_Attach
C_Form --> C_ModelsSelector
C_Form --> C_Attach
C_Message --> C_Attach
@ -190,6 +199,7 @@ end
%% Components use Hooks
C_Form --> H1
C_Message --> H1 & H2
C_MessageEditForm --> H1
C_Screen --> H2
%% Hooks use Stores
@ -244,7 +254,7 @@ end
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
class R1,R2,RL routeStyle
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message componentStyle
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageUser,C_MessageEditForm componentStyle
class C_ModelsSelector,C_Settings componentStyle
class C_Attach componentStyle
class H1,H2,H3 methodStyle

View File

@ -25,7 +25,7 @@
"@chromatic-com/storybook": "^4.1.2",
"@eslint/compat": "^1.2.5",
"@eslint/js": "^9.18.0",
"@internationalized/date": "^3.8.2",
"@internationalized/date": "^3.10.1",
"@lucide/svelte": "^0.515.0",
"@playwright/test": "^1.49.1",
"@storybook/addon-a11y": "^10.0.7",
@ -862,9 +862,9 @@
}
},
"node_modules/@internationalized/date": {
"version": "3.8.2",
"resolved": "https://registry.npmjs.org/@internationalized/date/-/date-3.8.2.tgz",
"integrity": "sha512-/wENk7CbvLbkUvX1tu0mwq49CVkkWpkXubGel6birjRPyo6uQ4nQpnq5xZu823zRCwwn82zgHrvgF1vZyvmVgA==",
"version": "3.10.1",
"resolved": "https://registry.npmjs.org/@internationalized/date/-/date-3.10.1.tgz",
"integrity": "sha512-oJrXtQiAXLvT9clCf1K4kxp3eKsQhIaZqxEyowkBcsvZDdZkbWrVmnGknxs5flTD0VGsxrxKgBCZty1EzoiMzA==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {

View File

@ -26,7 +26,7 @@
"@chromatic-com/storybook": "^4.1.2",
"@eslint/compat": "^1.2.5",
"@eslint/js": "^9.18.0",
"@internationalized/date": "^3.8.2",
"@internationalized/date": "^3.10.1",
"@lucide/svelte": "^0.515.0",
"@playwright/test": "^1.49.1",
"@storybook/addon-a11y": "^10.0.7",

View File

@ -8,6 +8,7 @@
ChatFormTextarea
} from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
import { config } from '$lib/stores/settings.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
@ -66,7 +67,7 @@
let message = $state('');
let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? 2500 : n;
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
});
let previousIsLoading = $state(isLoading);
let recordingSupported = $state(false);

View File

@ -12,13 +12,21 @@
onCopy?: (message: DatabaseMessage) => void;
onContinueAssistantMessage?: (message: DatabaseMessage) => void;
onDelete?: (message: DatabaseMessage) => void;
onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void;
onEditWithBranching?: (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => void;
onEditWithReplacement?: (
message: DatabaseMessage,
newContent: string,
shouldBranch: boolean
) => void;
onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void;
onEditUserMessagePreserveResponses?: (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => void;
onNavigateToSibling?: (siblingId: string) => void;
onRegenerateWithBranching?: (message: DatabaseMessage, modelOverride?: string) => void;
siblingInfo?: ChatMessageSiblingInfo | null;
@ -45,6 +53,8 @@
messageTypes: string[];
} | null>(null);
let editedContent = $state(message.content);
let editedExtras = $state<DatabaseMessageExtra[]>(message.extra ? [...message.extra] : []);
let editedUploadedFiles = $state<ChatUploadedFile[]>([]);
let isEditing = $state(false);
let showDeleteDialog = $state(false);
let shouldBranchAfterEdit = $state(false);
@ -85,6 +95,16 @@
function handleCancelEdit() {
isEditing = false;
editedContent = message.content;
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
}
function handleEditedExtrasChange(extras: DatabaseMessageExtra[]) {
editedExtras = extras;
}
function handleEditedUploadedFilesChange(files: ChatUploadedFile[]) {
editedUploadedFiles = files;
}
async function handleCopy() {
@ -107,6 +127,8 @@
function handleEdit() {
isEditing = true;
editedContent = message.content;
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
setTimeout(() => {
if (textareaElement) {
@ -143,9 +165,10 @@
onContinueAssistantMessage?.(message);
}
function handleSaveEdit() {
async function handleSaveEdit() {
if (message.role === 'user' || message.role === 'system') {
onEditWithBranching?.(message, editedContent.trim());
const finalExtras = await getMergedExtras();
onEditWithBranching?.(message, editedContent.trim(), finalExtras);
} else {
// For assistant messages, preserve exact content including trailing whitespace
// This is important for the Continue feature to work properly
@ -154,15 +177,30 @@
isEditing = false;
shouldBranchAfterEdit = false;
editedUploadedFiles = [];
}
function handleSaveEditOnly() {
async function handleSaveEditOnly() {
if (message.role === 'user') {
// For user messages, trim to avoid accidental whitespace
onEditUserMessagePreserveResponses?.(message, editedContent.trim());
const finalExtras = await getMergedExtras();
onEditUserMessagePreserveResponses?.(message, editedContent.trim(), finalExtras);
}
isEditing = false;
editedUploadedFiles = [];
}
async function getMergedExtras(): Promise<DatabaseMessageExtra[]> {
if (editedUploadedFiles.length === 0) {
return editedExtras;
}
const { parseFilesToMessageExtras } = await import('$lib/utils/browser-only');
const result = await parseFilesToMessageExtras(editedUploadedFiles);
const newExtras = result?.extras || [];
return [...editedExtras, ...newExtras];
}
function handleShowDeleteDialogChange(show: boolean) {
@ -197,6 +235,8 @@
class={className}
{deletionInfo}
{editedContent}
{editedExtras}
{editedUploadedFiles}
{isEditing}
{message}
onCancelEdit={handleCancelEdit}
@ -206,6 +246,8 @@
onEdit={handleEdit}
onEditKeydown={handleEditKeydown}
onEditedContentChange={handleEditedContentChange}
onEditedExtrasChange={handleEditedExtrasChange}
onEditedUploadedFilesChange={handleEditedUploadedFilesChange}
{onNavigateToSibling}
onSaveEdit={handleSaveEdit}
onSaveEditOnly={handleSaveEditOnly}

View File

@ -0,0 +1,391 @@
<script lang="ts">
import { X, ArrowUp, Paperclip, AlertTriangle } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Switch } from '$lib/components/ui/switch';
import { ChatAttachmentsList, DialogConfirmation, ModelsSelector } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
import { AttachmentType, FileTypeCategory, MimeTypeText } from '$lib/enums';
import { config } from '$lib/stores/settings.svelte';
import { useModelChangeValidation } from '$lib/hooks/use-model-change-validation.svelte';
import { setEditModeActive, clearEditMode } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { modelsStore } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import {
autoResizeTextarea,
getFileTypeCategory,
getFileTypeCategoryByExtension,
parseClipboardContent
} from '$lib/utils';
interface Props {
messageId: string;
editedContent: string;
editedExtras?: DatabaseMessageExtra[];
editedUploadedFiles?: ChatUploadedFile[];
originalContent: string;
originalExtras?: DatabaseMessageExtra[];
showSaveOnlyOption?: boolean;
onCancelEdit: () => void;
onSaveEdit: () => void;
onSaveEditOnly?: () => void;
onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void;
onEditedExtrasChange?: (extras: DatabaseMessageExtra[]) => void;
onEditedUploadedFilesChange?: (files: ChatUploadedFile[]) => void;
textareaElement?: HTMLTextAreaElement;
}
let {
messageId,
editedContent,
editedExtras = [],
editedUploadedFiles = [],
originalContent,
originalExtras = [],
showSaveOnlyOption = false,
onCancelEdit,
onSaveEdit,
onSaveEditOnly,
onEditKeydown,
onEditedContentChange,
onEditedExtrasChange,
onEditedUploadedFilesChange,
textareaElement = $bindable()
}: Props = $props();
let fileInputElement: HTMLInputElement | undefined = $state();
let saveWithoutRegenerate = $state(false);
let showDiscardDialog = $state(false);
let isRouter = $derived(isRouterMode());
let currentConfig = $derived(config());
let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
});
let hasUnsavedChanges = $derived.by(() => {
if (editedContent !== originalContent) return true;
if (editedUploadedFiles.length > 0) return true;
const extrasChanged =
editedExtras.length !== originalExtras.length ||
editedExtras.some((extra, i) => extra !== originalExtras[i]);
if (extrasChanged) return true;
return false;
});
let hasAttachments = $derived(
(editedExtras && editedExtras.length > 0) ||
(editedUploadedFiles && editedUploadedFiles.length > 0)
);
let canSubmit = $derived(editedContent.trim().length > 0 || hasAttachments);
function getEditedAttachmentsModalities(): ModelModalities {
const modalities: ModelModalities = { vision: false, audio: false };
for (const extra of editedExtras) {
if (extra.type === AttachmentType.IMAGE) {
modalities.vision = true;
}
if (
extra.type === AttachmentType.PDF &&
'processedAsImages' in extra &&
extra.processedAsImages
) {
modalities.vision = true;
}
if (extra.type === AttachmentType.AUDIO) {
modalities.audio = true;
}
}
for (const file of editedUploadedFiles) {
const category = getFileTypeCategory(file.type) || getFileTypeCategoryByExtension(file.name);
if (category === FileTypeCategory.IMAGE) {
modalities.vision = true;
}
if (category === FileTypeCategory.AUDIO) {
modalities.audio = true;
}
}
return modalities;
}
function getRequiredModalities(): ModelModalities {
const beforeModalities = conversationsStore.getModalitiesUpToMessage(messageId);
const editedModalities = getEditedAttachmentsModalities();
return {
vision: beforeModalities.vision || editedModalities.vision,
audio: beforeModalities.audio || editedModalities.audio
};
}
const { handleModelChange } = useModelChangeValidation({
getRequiredModalities,
onValidationFailure: async (previousModelId) => {
if (previousModelId) {
await modelsStore.selectModelById(previousModelId);
}
}
});
function handleFileInputChange(event: Event) {
const input = event.target as HTMLInputElement;
if (!input.files || input.files.length === 0) return;
const files = Array.from(input.files);
processNewFiles(files);
input.value = '';
}
function handleGlobalKeydown(event: KeyboardEvent) {
if (event.key === 'Escape') {
event.preventDefault();
attemptCancel();
}
}
function attemptCancel() {
if (hasUnsavedChanges) {
showDiscardDialog = true;
} else {
onCancelEdit();
}
}
function handleRemoveExistingAttachment(index: number) {
if (!onEditedExtrasChange) return;
const newExtras = [...editedExtras];
newExtras.splice(index, 1);
onEditedExtrasChange(newExtras);
}
function handleRemoveUploadedFile(fileId: string) {
if (!onEditedUploadedFilesChange) return;
const newFiles = editedUploadedFiles.filter((f) => f.id !== fileId);
onEditedUploadedFilesChange(newFiles);
}
function handleSubmit() {
if (!canSubmit) return;
if (saveWithoutRegenerate && onSaveEditOnly) {
onSaveEditOnly();
} else {
onSaveEdit();
}
saveWithoutRegenerate = false;
}
async function processNewFiles(files: File[]) {
if (!onEditedUploadedFilesChange) return;
const { processFilesToChatUploaded } = await import('$lib/utils/browser-only');
const processed = await processFilesToChatUploaded(files);
onEditedUploadedFilesChange([...editedUploadedFiles, ...processed]);
}
function handlePaste(event: ClipboardEvent) {
if (!event.clipboardData) return;
const files = Array.from(event.clipboardData.items)
.filter((item) => item.kind === 'file')
.map((item) => item.getAsFile())
.filter((file): file is File => file !== null);
if (files.length > 0) {
event.preventDefault();
processNewFiles(files);
return;
}
const text = event.clipboardData.getData(MimeTypeText.PLAIN);
if (text.startsWith('"')) {
const parsed = parseClipboardContent(text);
if (parsed.textAttachments.length > 0) {
event.preventDefault();
onEditedContentChange(parsed.message);
const attachmentFiles = parsed.textAttachments.map(
(att) =>
new File([att.content], att.name, {
type: MimeTypeText.PLAIN
})
);
processNewFiles(attachmentFiles);
setTimeout(() => {
textareaElement?.focus();
}, 10);
return;
}
}
if (
text.length > 0 &&
pasteLongTextToFileLength > 0 &&
text.length > pasteLongTextToFileLength
) {
event.preventDefault();
const textFile = new File([text], 'Pasted', {
type: MimeTypeText.PLAIN
});
processNewFiles([textFile]);
}
}
$effect(() => {
if (textareaElement) {
autoResizeTextarea(textareaElement);
}
});
$effect(() => {
setEditModeActive(processNewFiles);
return () => {
clearEditMode();
};
});
</script>
<svelte:window onkeydown={handleGlobalKeydown} />
<input
bind:this={fileInputElement}
type="file"
multiple
class="hidden"
onchange={handleFileInputChange}
/>
<div
class="{INPUT_CLASSES} w-full max-w-[80%] overflow-hidden rounded-3xl backdrop-blur-md"
data-slot="edit-form"
>
<ChatAttachmentsList
attachments={editedExtras}
uploadedFiles={editedUploadedFiles}
readonly={false}
onFileRemove={(fileId) => {
if (fileId.startsWith('attachment-')) {
const index = parseInt(fileId.replace('attachment-', ''), 10);
if (!isNaN(index) && index >= 0 && index < editedExtras.length) {
handleRemoveExistingAttachment(index);
}
} else {
handleRemoveUploadedFile(fileId);
}
}}
limitToSingleRow
class="py-5"
style="scroll-padding: 1rem;"
/>
<div class="relative min-h-[48px] px-5 py-3">
<textarea
bind:this={textareaElement}
bind:value={editedContent}
class="field-sizing-content max-h-80 min-h-10 w-full resize-none bg-transparent text-sm outline-none"
onkeydown={onEditKeydown}
oninput={(e) => {
autoResizeTextarea(e.currentTarget);
onEditedContentChange(e.currentTarget.value);
}}
onpaste={handlePaste}
placeholder="Edit your message..."
></textarea>
<div class="flex w-full items-center gap-3" style="container-type: inline-size">
<Button
class="h-8 w-8 shrink-0 rounded-full bg-transparent p-0 text-muted-foreground hover:bg-foreground/10 hover:text-foreground"
onclick={() => fileInputElement?.click()}
type="button"
title="Add attachment"
>
<span class="sr-only">Attach files</span>
<Paperclip class="h-4 w-4" />
</Button>
<div class="flex-1"></div>
{#if isRouter}
<ModelsSelector
forceForegroundText={true}
useGlobalSelection={true}
onModelChange={handleModelChange}
/>
{/if}
<Button
class="h-8 w-8 shrink-0 rounded-full p-0"
onclick={handleSubmit}
disabled={!canSubmit}
type="button"
title={saveWithoutRegenerate ? 'Save changes' : 'Send and regenerate'}
>
<span class="sr-only">{saveWithoutRegenerate ? 'Save' : 'Send'}</span>
<ArrowUp class="h-5 w-5" />
</Button>
</div>
</div>
</div>
<div class="mt-2 flex w-full max-w-[80%] items-center justify-between">
{#if showSaveOnlyOption && onSaveEditOnly}
<div class="flex items-center gap-2">
<Switch id="save-only-switch" bind:checked={saveWithoutRegenerate} class="scale-75" />
<label for="save-only-switch" class="cursor-pointer text-xs text-muted-foreground">
Update without re-sending
</label>
</div>
{:else}
<div></div>
{/if}
<Button class="h-7 px-3 text-xs" onclick={attemptCancel} size="sm" variant="ghost">
<X class="mr-1 h-3 w-3" />
Cancel
</Button>
</div>
<DialogConfirmation
bind:open={showDiscardDialog}
title="Discard changes?"
description="You have unsaved changes. Are you sure you want to discard them?"
confirmText="Discard"
cancelText="Keep editing"
variant="destructive"
icon={AlertTriangle}
onConfirm={onCancelEdit}
onCancel={() => (showDiscardDialog = false)}
/>

View File

@ -1,18 +1,17 @@
<script lang="ts">
import { Check, X, Send } from '@lucide/svelte';
import { Card } from '$lib/components/ui/card';
import { Button } from '$lib/components/ui/button';
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { config } from '$lib/stores/settings.svelte';
import { autoResizeTextarea } from '$lib/utils';
import ChatMessageActions from './ChatMessageActions.svelte';
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
interface Props {
class?: string;
message: DatabaseMessage;
isEditing: boolean;
editedContent: string;
editedExtras?: DatabaseMessageExtra[];
editedUploadedFiles?: ChatUploadedFile[];
siblingInfo?: ChatMessageSiblingInfo | null;
showDeleteDialog: boolean;
deletionInfo: {
@ -26,6 +25,8 @@
onSaveEditOnly?: () => void;
onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void;
onEditedExtrasChange?: (extras: DatabaseMessageExtra[]) => void;
onEditedUploadedFilesChange?: (files: ChatUploadedFile[]) => void;
onCopy: () => void;
onEdit: () => void;
onDelete: () => void;
@ -40,6 +41,8 @@
message,
isEditing,
editedContent,
editedExtras = [],
editedUploadedFiles = [],
siblingInfo = null,
showDeleteDialog,
deletionInfo,
@ -48,6 +51,8 @@
onSaveEditOnly,
onEditKeydown,
onEditedContentChange,
onEditedExtrasChange,
onEditedUploadedFilesChange,
onCopy,
onEdit,
onDelete,
@ -61,12 +66,6 @@
let messageElement: HTMLElement | undefined = $state();
const currentConfig = config();
$effect(() => {
if (isEditing && textareaElement) {
autoResizeTextarea(textareaElement);
}
});
$effect(() => {
if (!messageElement || !message.content.trim()) return;
@ -98,44 +97,23 @@
role="group"
>
{#if isEditing}
<div class="w-full max-w-[80%]">
<textarea
bind:this={textareaElement}
bind:value={editedContent}
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
onkeydown={onEditKeydown}
oninput={(e) => {
autoResizeTextarea(e.currentTarget);
onEditedContentChange(e.currentTarget.value);
}}
placeholder="Edit your message..."
></textarea>
<div class="mt-2 flex justify-end gap-2">
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="ghost">
<X class="mr-1 h-3 w-3" />
Cancel
</Button>
{#if onSaveEditOnly}
<Button
class="h-8 px-3"
onclick={onSaveEditOnly}
disabled={!editedContent.trim()}
size="sm"
variant="outline"
>
<Check class="mr-1 h-3 w-3" />
Save
</Button>
{/if}
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
<Send class="mr-1 h-3 w-3" />
Send
</Button>
</div>
</div>
<ChatMessageEditForm
bind:textareaElement
messageId={message.id}
{editedContent}
{editedExtras}
{editedUploadedFiles}
originalContent={message.content}
originalExtras={message.extra}
showSaveOnlyOption={!!onSaveEditOnly}
{onCancelEdit}
{onSaveEdit}
{onSaveEditOnly}
{onEditKeydown}
{onEditedContentChange}
{onEditedExtrasChange}
{onEditedUploadedFilesChange}
/>
{:else}
{#if message.extra && message.extra.length > 0}
<div class="mb-2 max-w-[80%]">

View File

@ -66,10 +66,14 @@
await conversationsStore.navigateToSibling(siblingId);
}
async function handleEditWithBranching(message: DatabaseMessage, newContent: string) {
async function handleEditWithBranching(
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) {
onUserAction?.();
await chatStore.editMessageWithBranching(message.id, newContent);
await chatStore.editMessageWithBranching(message.id, newContent, newExtras);
refreshAllMessages();
}
@ -104,11 +108,12 @@
async function handleEditUserMessagePreserveResponses(
message: DatabaseMessage,
newContent: string
newContent: string,
newExtras?: DatabaseMessageExtra[]
) {
onUserAction?.();
await chatStore.editUserMessagePreserveResponses(message.id, newContent);
await chatStore.editUserMessagePreserveResponses(message.id, newContent, newExtras);
refreshAllMessages();
}

View File

@ -17,7 +17,13 @@
AUTO_SCROLL_INTERVAL,
INITIAL_SCROLL_DELAY
} from '$lib/constants/auto-scroll';
import { chatStore, errorDialog, isLoading } from '$lib/stores/chat.svelte';
import {
chatStore,
errorDialog,
isLoading,
isEditing,
getAddFilesHandler
} from '$lib/stores/chat.svelte';
import {
conversationsStore,
activeMessages,
@ -181,7 +187,18 @@
dragCounter = 0;
if (event.dataTransfer?.files) {
processFiles(Array.from(event.dataTransfer.files));
const files = Array.from(event.dataTransfer.files);
if (isEditing()) {
const handler = getAddFilesHandler();
if (handler) {
handler(files);
return;
}
}
processFiles(files);
}
}
@ -410,7 +427,7 @@
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl pb-4">
<ChatForm
disabled={hasPropsError}
disabled={hasPropsError || isEditing()}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}

View File

@ -0,0 +1,7 @@
import Root from './switch.svelte';
export {
Root,
//
Root as Switch
};

View File

@ -0,0 +1,29 @@
<script lang="ts">
import { Switch as SwitchPrimitive } from 'bits-ui';
import { cn, type WithoutChildrenOrChild } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
class: className,
checked = $bindable(false),
...restProps
}: WithoutChildrenOrChild<SwitchPrimitive.RootProps> = $props();
</script>
<SwitchPrimitive.Root
bind:ref
bind:checked
data-slot="switch"
class={cn(
'peer inline-flex h-[1.15rem] w-8 shrink-0 items-center rounded-full border border-transparent shadow-xs transition-all outline-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input dark:data-[state=unchecked]:bg-input/80',
className
)}
{...restProps}
>
<SwitchPrimitive.Thumb
data-slot="switch-thumb"
class={cn(
'pointer-events-none block size-4 rounded-full bg-background ring-0 transition-transform data-[state=checked]:translate-x-[calc(100%-2px)] data-[state=unchecked]:translate-x-0 dark:data-[state=checked]:bg-primary-foreground dark:data-[state=unchecked]:bg-foreground'
)}
/>
</SwitchPrimitive.Root>

View File

@ -74,6 +74,8 @@ class ChatStore {
private processingStates = new SvelteMap<string, ApiProcessingState | null>();
private activeConversationId = $state<string | null>(null);
private isStreamingActive = $state(false);
private isEditModeActive = $state(false);
private addFilesHandler: ((files: File[]) => void) | null = $state(null);
// ─────────────────────────────────────────────────────────────────────────────
// Loading State
@ -965,230 +967,9 @@ class ChatStore {
// Editing
// ─────────────────────────────────────────────────────────────────────────────
async editAssistantMessage(
messageId: string,
newContent: string,
shouldBranch: boolean
): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
const result = this.getMessageByIdWithRole(messageId, 'assistant');
if (!result) return;
const { message: msg, index: idx } = result;
try {
if (shouldBranch) {
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
thinking: msg.thinking || '',
toolCalls: msg.toolCalls || '',
children: [],
model: msg.model
},
msg.parent!
);
await conversationsStore.updateCurrentNode(newMessage.id);
} else {
await DatabaseService.updateMessage(msg.id, { content: newContent, timestamp: Date.now() });
await conversationsStore.updateCurrentNode(msg.id);
conversationsStore.updateMessageAtIndex(idx, {
content: newContent,
timestamp: Date.now()
});
}
conversationsStore.updateConversationTimestamp();
await conversationsStore.refreshActiveMessages();
} catch (error) {
console.error('Failed to edit assistant message:', error);
}
}
async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv) return;
const result = this.getMessageByIdWithRole(messageId, 'user');
if (!result) return;
const { message: msg, index: idx } = result;
try {
await DatabaseService.updateMessage(messageId, {
content: newContent,
timestamp: Date.now()
});
conversationsStore.updateMessageAtIndex(idx, { content: newContent, timestamp: Date.now() });
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim(),
conversationsStore.titleUpdateConfirmationCallback
);
}
conversationsStore.updateConversationTimestamp();
} catch (error) {
console.error('Failed to edit user message:', error);
}
}
async editMessageWithBranching(messageId: string, newContent: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
let result = this.getMessageByIdWithRole(messageId, 'user');
if (!result) {
result = this.getMessageByIdWithRole(messageId, 'system');
}
if (!result) return;
const { message: msg } = result;
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const isFirstUserMessage =
msg.role === 'user' && rootMessage && msg.parent === rootMessage.id;
const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
thinking: msg.thinking || '',
toolCalls: msg.toolCalls || '',
children: [],
extra: msg.extra ? JSON.parse(JSON.stringify(msg.extra)) : undefined,
model: msg.model
},
parentId
);
await conversationsStore.updateCurrentNode(newMessage.id);
conversationsStore.updateConversationTimestamp();
if (isFirstUserMessage && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim(),
conversationsStore.titleUpdateConfirmationCallback
);
}
await conversationsStore.refreshActiveMessages();
if (msg.role === 'user') {
await this.generateResponseForMessage(newMessage.id);
}
} catch (error) {
console.error('Failed to edit message with branching:', error);
}
}
async regenerateMessageWithBranching(messageId: string, modelOverride?: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
try {
const idx = conversationsStore.findMessageIndex(messageId);
if (idx === -1) return;
const msg = conversationsStore.activeMessages[idx];
if (msg.role !== 'assistant') return;
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const parentMessage = allMessages.find((m) => m.id === msg.parent);
if (!parentMessage) return;
this.setChatLoading(activeConv.id, true);
this.clearChatStreaming(activeConv.id);
const newAssistantMessage = await DatabaseService.createMessageBranch(
{
convId: activeConv.id,
type: 'text',
timestamp: Date.now(),
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
parentMessage.id
);
await conversationsStore.updateCurrentNode(newAssistantMessage.id);
conversationsStore.updateConversationTimestamp();
await conversationsStore.refreshActiveMessages();
const conversationPath = filterByLeafNodeId(
allMessages,
parentMessage.id,
false
) as DatabaseMessage[];
// Use modelOverride if provided, otherwise use the original message's model
// If neither is available, don't pass model (will use global selection)
const modelToUse = modelOverride || msg.model || undefined;
await this.streamChatCompletion(
conversationPath,
newAssistantMessage,
undefined,
undefined,
modelToUse
);
} catch (error) {
if (!this.isAbortError(error))
console.error('Failed to regenerate message with branching:', error);
this.setChatLoading(activeConv?.id || '', false);
}
}
private async generateResponseForMessage(userMessageId: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv) return;
this.errorDialogState = null;
this.setChatLoading(activeConv.id, true);
this.clearChatStreaming(activeConv.id);
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const conversationPath = filterByLeafNodeId(
allMessages,
userMessageId,
false
) as DatabaseMessage[];
const assistantMessage = await DatabaseService.createMessageBranch(
{
convId: activeConv.id,
type: 'text',
timestamp: Date.now(),
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
userMessageId
);
conversationsStore.addMessageToActive(assistantMessage);
await this.streamChatCompletion(conversationPath, assistantMessage);
} catch (error) {
console.error('Failed to generate response:', error);
this.setChatLoading(activeConv.id, false);
}
clearEditMode(): void {
this.isEditModeActive = false;
this.addFilesHandler = null;
}
async continueAssistantMessage(messageId: string): Promise<void> {
@ -1340,19 +1121,284 @@ class ChatStore {
}
}
public isChatLoadingPublic(convId: string): boolean {
return this.isChatLoading(convId);
async editAssistantMessage(
messageId: string,
newContent: string,
shouldBranch: boolean
): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
const result = this.getMessageByIdWithRole(messageId, 'assistant');
if (!result) return;
const { message: msg, index: idx } = result;
try {
if (shouldBranch) {
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
thinking: msg.thinking || '',
toolCalls: msg.toolCalls || '',
children: [],
model: msg.model
},
msg.parent!
);
await conversationsStore.updateCurrentNode(newMessage.id);
} else {
await DatabaseService.updateMessage(msg.id, { content: newContent });
await conversationsStore.updateCurrentNode(msg.id);
conversationsStore.updateMessageAtIndex(idx, {
content: newContent
});
}
conversationsStore.updateConversationTimestamp();
await conversationsStore.refreshActiveMessages();
} catch (error) {
console.error('Failed to edit assistant message:', error);
}
}
async editUserMessagePreserveResponses(
messageId: string,
newContent: string,
newExtras?: DatabaseMessageExtra[]
): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv) return;
const result = this.getMessageByIdWithRole(messageId, 'user');
if (!result) return;
const { message: msg, index: idx } = result;
try {
const updateData: Partial<DatabaseMessage> = {
content: newContent
};
// Update extras if provided (including empty array to clear attachments)
// Deep clone to avoid Proxy objects from Svelte reactivity
if (newExtras !== undefined) {
updateData.extra = JSON.parse(JSON.stringify(newExtras));
}
await DatabaseService.updateMessage(messageId, updateData);
conversationsStore.updateMessageAtIndex(idx, updateData);
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim(),
conversationsStore.titleUpdateConfirmationCallback
);
}
conversationsStore.updateConversationTimestamp();
} catch (error) {
console.error('Failed to edit user message:', error);
}
}
async editMessageWithBranching(
messageId: string,
newContent: string,
newExtras?: DatabaseMessageExtra[]
): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
let result = this.getMessageByIdWithRole(messageId, 'user');
if (!result) {
result = this.getMessageByIdWithRole(messageId, 'system');
}
if (!result) return;
const { message: msg } = result;
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const isFirstUserMessage =
msg.role === 'user' && rootMessage && msg.parent === rootMessage.id;
const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;
// Use newExtras if provided, otherwise copy existing extras
// Deep clone to avoid Proxy objects from Svelte reactivity
const extrasToUse =
newExtras !== undefined
? JSON.parse(JSON.stringify(newExtras))
: msg.extra
? JSON.parse(JSON.stringify(msg.extra))
: undefined;
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
thinking: msg.thinking || '',
toolCalls: msg.toolCalls || '',
children: [],
extra: extrasToUse,
model: msg.model
},
parentId
);
await conversationsStore.updateCurrentNode(newMessage.id);
conversationsStore.updateConversationTimestamp();
if (isFirstUserMessage && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim(),
conversationsStore.titleUpdateConfirmationCallback
);
}
await conversationsStore.refreshActiveMessages();
if (msg.role === 'user') {
await this.generateResponseForMessage(newMessage.id);
}
} catch (error) {
console.error('Failed to edit message with branching:', error);
}
}
async regenerateMessageWithBranching(messageId: string, modelOverride?: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
try {
const idx = conversationsStore.findMessageIndex(messageId);
if (idx === -1) return;
const msg = conversationsStore.activeMessages[idx];
if (msg.role !== 'assistant') return;
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const parentMessage = allMessages.find((m) => m.id === msg.parent);
if (!parentMessage) return;
this.setChatLoading(activeConv.id, true);
this.clearChatStreaming(activeConv.id);
const newAssistantMessage = await DatabaseService.createMessageBranch(
{
convId: activeConv.id,
type: 'text',
timestamp: Date.now(),
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
parentMessage.id
);
await conversationsStore.updateCurrentNode(newAssistantMessage.id);
conversationsStore.updateConversationTimestamp();
await conversationsStore.refreshActiveMessages();
const conversationPath = filterByLeafNodeId(
allMessages,
parentMessage.id,
false
) as DatabaseMessage[];
// Use modelOverride if provided, otherwise use the original message's model
// If neither is available, don't pass model (will use global selection)
const modelToUse = modelOverride || msg.model || undefined;
await this.streamChatCompletion(
conversationPath,
newAssistantMessage,
undefined,
undefined,
modelToUse
);
} catch (error) {
if (!this.isAbortError(error))
console.error('Failed to regenerate message with branching:', error);
this.setChatLoading(activeConv?.id || '', false);
}
}
private async generateResponseForMessage(userMessageId: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv) return;
this.errorDialogState = null;
this.setChatLoading(activeConv.id, true);
this.clearChatStreaming(activeConv.id);
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const conversationPath = filterByLeafNodeId(
allMessages,
userMessageId,
false
) as DatabaseMessage[];
const assistantMessage = await DatabaseService.createMessageBranch(
{
convId: activeConv.id,
type: 'text',
timestamp: Date.now(),
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
userMessageId
);
conversationsStore.addMessageToActive(assistantMessage);
await this.streamChatCompletion(conversationPath, assistantMessage);
} catch (error) {
console.error('Failed to generate response:', error);
this.setChatLoading(activeConv.id, false);
}
}
getAddFilesHandler(): ((files: File[]) => void) | null {
return this.addFilesHandler;
}
public getAllLoadingChats(): string[] {
return Array.from(this.chatLoadingStates.keys());
}
public getAllStreamingChats(): string[] {
return Array.from(this.chatStreamingStates.keys());
}
public getChatStreamingPublic(
convId: string
): { response: string; messageId: string } | undefined {
return this.getChatStreaming(convId);
}
public getAllLoadingChats(): string[] {
return Array.from(this.chatLoadingStates.keys());
public isChatLoadingPublic(convId: string): boolean {
return this.isChatLoading(convId);
}
public getAllStreamingChats(): string[] {
return Array.from(this.chatStreamingStates.keys());
isEditing(): boolean {
return this.isEditModeActive;
}
setEditModeActive(handler: (files: File[]) => void): void {
this.isEditModeActive = true;
this.addFilesHandler = handler;
}
// ─────────────────────────────────────────────────────────────────────────────
@ -1418,13 +1464,17 @@ class ChatStore {
export const chatStore = new ChatStore();
export const isLoading = () => chatStore.isLoading;
export const activeProcessingState = () => chatStore.activeProcessingState;
export const clearEditMode = () => chatStore.clearEditMode();
export const currentResponse = () => chatStore.currentResponse;
export const errorDialog = () => chatStore.errorDialogState;
export const activeProcessingState = () => chatStore.activeProcessingState;
export const isChatStreaming = () => chatStore.isStreaming();
export const isChatLoading = (convId: string) => chatStore.isChatLoadingPublic(convId);
export const getChatStreaming = (convId: string) => chatStore.getChatStreamingPublic(convId);
export const getAddFilesHandler = () => chatStore.getAddFilesHandler();
export const getAllLoadingChats = () => chatStore.getAllLoadingChats();
export const getAllStreamingChats = () => chatStore.getAllStreamingChats();
export const getChatStreaming = (convId: string) => chatStore.getChatStreamingPublic(convId);
export const isChatLoading = (convId: string) => chatStore.isChatLoadingPublic(convId);
export const isChatStreaming = () => chatStore.isStreaming();
export const isEditing = () => chatStore.isEditing();
export const isLoading = () => chatStore.isLoading;
export const setEditModeActive = (handler: (files: File[]) => void) =>
chatStore.setEditModeActive(handler);