Merge branch 'ggml-org:master' into power-law-sampler

This commit is contained in:
ddh0 2025-12-19 17:53:19 -06:00 committed by GitHub
commit f4703d422c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 2402 additions and 870 deletions

View File

@ -70,6 +70,7 @@ jobs:
with: with:
key: macOS-latest-cmake-arm64 key: macOS-latest-cmake-arm64
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -106,6 +107,7 @@ jobs:
with: with:
key: macOS-latest-cmake-x64 key: macOS-latest-cmake-x64
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -142,6 +144,7 @@ jobs:
with: with:
key: macOS-latest-cmake-arm64-webgpu key: macOS-latest-cmake-arm64-webgpu
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dawn Dependency - name: Dawn Dependency
id: dawn-depends id: dawn-depends
@ -195,6 +198,7 @@ jobs:
with: with:
key: ubuntu-cpu-cmake-${{ matrix.build }} key: ubuntu-cpu-cmake-${{ matrix.build }}
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build Dependencies - name: Build Dependencies
id: build_depends id: build_depends
@ -276,6 +280,7 @@ jobs:
with: with:
key: ubuntu-latest-cmake-sanitizer-${{ matrix.sanitizer }} key: ubuntu-latest-cmake-sanitizer-${{ matrix.sanitizer }}
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -396,6 +401,7 @@ jobs:
with: with:
key: ubuntu-24-cmake-vulkan-deb key: ubuntu-24-cmake-vulkan-deb
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -431,6 +437,7 @@ jobs:
with: with:
key: ubuntu-24-cmake-vulkan key: ubuntu-24-cmake-vulkan
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -490,6 +497,7 @@ jobs:
with: with:
key: ubuntu-24-cmake-webgpu key: ubuntu-24-cmake-webgpu
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -562,6 +570,7 @@ jobs:
with: with:
key: ubuntu-latest-wasm-webgpu key: ubuntu-latest-wasm-webgpu
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install Emscripten - name: Install Emscripten
run: | run: |
@ -609,6 +618,7 @@ jobs:
with: with:
key: ubuntu-22-cmake-hip key: ubuntu-22-cmake-hip
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with native CMake HIP support - name: Build with native CMake HIP support
id: cmake_build id: cmake_build
@ -641,6 +651,7 @@ jobs:
with: with:
key: ubuntu-22-cmake-musa key: ubuntu-22-cmake-musa
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with native CMake MUSA support - name: Build with native CMake MUSA support
id: cmake_build id: cmake_build
@ -688,6 +699,7 @@ jobs:
with: with:
key: ubuntu-22-cmake-sycl key: ubuntu-22-cmake-sycl
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -738,6 +750,7 @@ jobs:
with: with:
key: ubuntu-22-cmake-sycl-fp16 key: ubuntu-22-cmake-sycl-fp16
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -771,6 +784,7 @@ jobs:
with: with:
key: macOS-latest-cmake-ios key: macOS-latest-cmake-ios
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -802,6 +816,7 @@ jobs:
with: with:
key: macOS-latest-cmake-tvos key: macOS-latest-cmake-tvos
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -863,6 +878,7 @@ jobs:
with: with:
key: macOS-latest-swift key: macOS-latest-swift
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Download xcframework artifact - name: Download xcframework artifact
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
@ -905,6 +921,7 @@ jobs:
key: windows-msys2 key: windows-msys2
variant: ccache variant: ccache
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Setup ${{ matrix.sys }} - name: Setup ${{ matrix.sys }}
uses: msys2/setup-msys2@v2 uses: msys2/setup-msys2@v2
@ -973,6 +990,7 @@ jobs:
key: windows-latest-cmake-${{ matrix.build }} key: windows-latest-cmake-${{ matrix.build }}
variant: ccache variant: ccache
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Download OpenBLAS - name: Download OpenBLAS
id: get_openblas id: get_openblas
@ -1077,6 +1095,7 @@ jobs:
with: with:
key: ubuntu-latest-cmake-cuda key: ubuntu-latest-cmake-cuda
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with CMake - name: Build with CMake
run: | run: |
@ -1109,6 +1128,7 @@ jobs:
key: windows-cuda-${{ matrix.cuda }} key: windows-cuda-${{ matrix.cuda }}
variant: ccache variant: ccache
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install Cuda Toolkit - name: Install Cuda Toolkit
uses: ./.github/actions/windows-setup-cuda uses: ./.github/actions/windows-setup-cuda
@ -1160,6 +1180,7 @@ jobs:
key: windows-latest-cmake-sycl key: windows-latest-cmake-sycl
variant: ccache variant: ccache
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install - name: Install
run: | run: |
@ -1221,6 +1242,7 @@ jobs:
with: with:
key: ${{ github.job }} key: ${{ github.job }}
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -1466,6 +1488,7 @@ jobs:
with: with:
key: ggml-ci-x64-cpu-low-perf key: ggml-ci-x64-cpu-low-perf
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -1491,6 +1514,7 @@ jobs:
with: with:
key: ggml-ci-arm64-cpu-low-perf key: ggml-ci-arm64-cpu-low-perf
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -1516,6 +1540,7 @@ jobs:
with: with:
key: ggml-ci-x64-cpu-high-perf key: ggml-ci-x64-cpu-high-perf
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -1541,6 +1566,7 @@ jobs:
with: with:
key: ggml-ci-arm64-cpu-high-perf key: ggml-ci-arm64-cpu-high-perf
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -1566,6 +1592,7 @@ jobs:
with: with:
key: ggml-ci-arm64-cpu-high-perf-sve key: ggml-ci-arm64-cpu-high-perf-sve
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -1701,6 +1728,7 @@ jobs:
with: with:
key: ggml-ci-arm64-cpu-kleidiai key: ggml-ci-arm64-cpu-kleidiai
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies - name: Dependencies
id: depends id: depends
@ -2084,6 +2112,7 @@ jobs:
with: with:
key: ggml-ci-arm64-graviton4-kleidiai key: ggml-ci-arm64-graviton4-kleidiai
evict-old-files: 1d evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Test - name: Test
id: ggml-ci id: ggml-ci

View File

@ -66,16 +66,9 @@ jobs:
id: pack_artifacts id: pack_artifacts
run: | run: |
cp LICENSE ./build/bin/ cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (zip) - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip
name: llama-bin-macos-arm64.zip
- name: Upload artifacts (tar)
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
@ -127,16 +120,9 @@ jobs:
id: pack_artifacts id: pack_artifacts
run: | run: |
cp LICENSE ./build/bin/ cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (zip) - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
name: llama-bin-macos-x64.zip
- name: Upload artifacts (tar)
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
@ -196,16 +182,9 @@ jobs:
id: pack_artifacts id: pack_artifacts
run: | run: |
cp LICENSE ./build/bin/ cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (zip) - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip
name: llama-bin-ubuntu-${{ matrix.build }}.zip
- name: Upload artifacts (tar)
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz
@ -256,16 +235,9 @@ jobs:
id: pack_artifacts id: pack_artifacts
run: | run: |
cp LICENSE ./build/bin/ cp LICENSE ./build/bin/
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (zip) - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip
name: llama-bin-ubuntu-vulkan-x64.zip
- name: Upload artifacts (tar)
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz
@ -716,16 +688,9 @@ jobs:
- name: Pack artifacts - name: Pack artifacts
id: pack_artifacts id: pack_artifacts
run: | run: |
zip -y -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
tar -czvf llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz -C build-apple llama.xcframework tar -czvf llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz -C build-apple llama.xcframework
- name: Upload artifacts (zip) - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
name: llama-${{ steps.tag.outputs.name }}-xcframework.zip
- name: Upload artifacts (tar)
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
@ -797,7 +762,7 @@ jobs:
cp LICENSE ./build/bin/ cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts (tar) - name: Upload artifacts
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
@ -889,9 +854,6 @@ jobs:
with: with:
tag_name: ${{ steps.tag.outputs.name }} tag_name: ${{ steps.tag.outputs.name }}
body: | body: |
> [!WARNING]
> **Release Format Update**: Linux releases will soon use .tar.gz archives instead of .zip. Please make the necessary changes to your deployment scripts.
<details open> <details open>
${{ github.event.head_commit.message }} ${{ github.event.head_commit.message }}
@ -911,8 +873,8 @@ jobs:
**Windows:** **Windows:**
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip) - [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip) - [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.1-x64.zip) - [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.1-x64.zip) - [CUDA 13.1 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.1-x64.zip)
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip) - [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
- [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip) - [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip)
- [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip) - [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip)

View File

@ -772,6 +772,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
} }
auto opt = *arg_to_options[arg]; auto opt = *arg_to_options[arg];
std::string val; std::string val;
if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
// bool arg (need to reverse the meaning for negative args)
bool is_neg = std::find(opt.args_neg.begin(), opt.args_neg.end(), arg) != opt.args_neg.end();
val = is_neg ? "0" : "1";
}
if (opt.value_hint != nullptr) { if (opt.value_hint != nullptr) {
// arg with single value // arg with single value
check_arg(i); check_arg(i);
@ -1139,7 +1144,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg( add_opt(common_arg(
{"--cache-ram", "-cram"}, "N", {"-cram", "--cache-ram"}, "N",
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)" string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib), "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
[](common_params & params, int value) { [](common_params & params, int value) {
@ -1147,7 +1152,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg( add_opt(common_arg(
{"--kv-unified", "-kvu"}, {"-kvu", "--kv-unified"},
"use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)", "use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)",
[](common_params & params) { [](common_params & params) {
params.kv_unified = true; params.kv_unified = true;
@ -1196,7 +1201,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.system_prompt = value; params.system_prompt = value;
} }
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION})); ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_MTMD}));
add_opt(common_arg( add_opt(common_arg(
{"--perf"}, {"--perf"},
{"--no-perf"}, {"--no-perf"},
@ -1415,7 +1420,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--sampling-seq", "--sampler-seq"}, "SEQUENCE", {"--sampler-seq", "--sampling-seq"}, "SEQUENCE",
string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()), string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.samplers = common_sampler_types_from_chars(value); params.sampling.samplers = common_sampler_types_from_chars(value);
@ -2091,26 +2096,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
)); ));
add_opt(common_arg( add_opt(common_arg(
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...", {"-ot", "--override-tensor"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type", [](common_params & params, const std::string & value) { "override tensor buffer type", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides); parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
} }
)); ));
add_opt(common_arg( add_opt(common_arg(
{"--override-tensor-draft", "-otd"}, "<tensor name pattern>=<buffer type>,...", {"-otd", "--override-tensor-draft"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) { "override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides); parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
} }
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg( add_opt(common_arg(
{"--cpu-moe", "-cmoe"}, {"-cmoe", "--cpu-moe"},
"keep all Mixture of Experts (MoE) weights in the CPU", "keep all Mixture of Experts (MoE) weights in the CPU",
[](common_params & params) { [](common_params & params) {
params.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override()); params.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
} }
).set_env("LLAMA_ARG_CPU_MOE")); ).set_env("LLAMA_ARG_CPU_MOE"));
add_opt(common_arg( add_opt(common_arg(
{"--n-cpu-moe", "-ncmoe"}, "N", {"-ncmoe", "--n-cpu-moe"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU", "keep the Mixture of Experts (MoE) weights of the first N layers in the CPU",
[](common_params & params, int value) { [](common_params & params, int value) {
if (value < 0) { if (value < 0) {
@ -2125,14 +2130,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
).set_env("LLAMA_ARG_N_CPU_MOE")); ).set_env("LLAMA_ARG_N_CPU_MOE"));
add_opt(common_arg( add_opt(common_arg(
{"--cpu-moe-draft", "-cmoed"}, {"-cmoed", "--cpu-moe-draft"},
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model", "keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
[](common_params & params) { [](common_params & params) {
params.speculative.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override()); params.speculative.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
} }
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_CPU_MOE_DRAFT")); ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
add_opt(common_arg( add_opt(common_arg(
{"--n-cpu-moe-draft", "-ncmoed"}, "N", {"-ncmoed", "--n-cpu-moe-draft"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model", "keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
[](common_params & params, int value) { [](common_params & params, int value) {
if (value < 0) { if (value < 0) {
@ -2660,7 +2665,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(common_arg( add_opt(common_arg(
{"--reranking", "--rerank"}, {"--rerank", "--reranking"},
string_format("enable reranking endpoint on server (default: %s)", "disabled"), string_format("enable reranking endpoint on server (default: %s)", "disabled"),
[](common_params & params) { [](common_params & params) {
params.embedding = true; params.embedding = true;
@ -3131,7 +3136,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg( add_opt(common_arg(
{"--draft-max", "--draft", "--draft-n"}, "N", {"--draft", "--draft-n", "--draft-max"}, "N",
string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max), string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max),
[](common_params & params, int value) { [](common_params & params, int value) {
params.speculative.n_max = value; params.speculative.n_max = value;

View File

@ -2,6 +2,7 @@
#include "preset.h" #include "preset.h"
#include "peg-parser.h" #include "peg-parser.h"
#include "log.h" #include "log.h"
#include "download.h"
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
@ -15,9 +16,13 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos); return str.substr(pos);
} }
std::vector<std::string> common_preset::to_args() const { std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args; std::vector<std::string> args;
if (!bin_path.empty()) {
args.push_back(bin_path);
}
for (const auto & [opt, value] : options) { for (const auto & [opt, value] : options) {
args.push_back(opt.args.back()); // use the last arg as the main arg args.push_back(opt.args.back()); // use the last arg as the main arg
if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) { if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
@ -63,6 +68,52 @@ std::string common_preset::to_ini() const {
return ss.str(); return ss.str();
} }
void common_preset::set_option(const common_preset_context & ctx, const std::string & env, const std::string & value) {
// try if option exists, update it
for (auto & [opt, val] : options) {
if (opt.env && env == opt.env) {
val = value;
return;
}
}
// if option does not exist, we need to add it
if (ctx.key_to_opt.find(env) == ctx.key_to_opt.end()) {
throw std::runtime_error(string_format(
"%s: option with env '%s' not found in ctx_params",
__func__, env.c_str()
));
}
options[ctx.key_to_opt.at(env)] = value;
}
void common_preset::unset_option(const std::string & env) {
for (auto it = options.begin(); it != options.end(); ) {
const common_arg & opt = it->first;
if (opt.env && env == opt.env) {
it = options.erase(it);
return;
} else {
++it;
}
}
}
bool common_preset::get_option(const std::string & env, std::string & value) const {
for (const auto & [opt, val] : options) {
if (opt.env && env == opt.env) {
value = val;
return true;
}
}
return false;
}
void common_preset::merge(const common_preset & other) {
for (const auto & [opt, val] : other.options) {
options[opt] = val; // overwrite existing options
}
}
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) { static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
std::map<std::string, std::map<std::string, std::string>> parsed; std::map<std::string, std::map<std::string, std::string>> parsed;
@ -172,9 +223,12 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
return value; return value;
} }
common_presets common_presets_load(const std::string & path, common_params_context & ctx_params) { common_preset_context::common_preset_context(llama_example ex)
: ctx_params(common_params_parser_init(default_params, ex)),
key_to_opt(get_map_key_opt(ctx_params)) {}
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
common_presets out; common_presets out;
auto key_to_opt = get_map_key_opt(ctx_params);
auto ini_data = parse_ini_from_file(path); auto ini_data = parse_ini_from_file(path);
for (auto section : ini_data) { for (auto section : ini_data) {
@ -188,7 +242,7 @@ common_presets common_presets_load(const std::string & path, common_params_conte
for (const auto & [key, value] : section.second) { for (const auto & [key, value] : section.second) {
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str()); LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
if (key_to_opt.find(key) != key_to_opt.end()) { if (key_to_opt.find(key) != key_to_opt.end()) {
auto & opt = key_to_opt[key]; const auto & opt = key_to_opt.at(key);
if (is_bool_arg(opt)) { if (is_bool_arg(opt)) {
preset.options[opt] = parse_bool_arg(opt, key, value); preset.options[opt] = parse_bool_arg(opt, key, value);
} else { } else {
@ -199,8 +253,137 @@ common_presets common_presets_load(const std::string & path, common_params_conte
// TODO: maybe warn about unknown key? // TODO: maybe warn about unknown key?
} }
} }
if (preset.name == "*") {
// handle global preset
global = preset;
} else {
out[preset.name] = preset;
}
}
return out;
}
common_presets common_preset_context::load_from_cache() const {
common_presets out;
auto cached_models = common_list_cached_models();
for (const auto & model : cached_models) {
common_preset preset;
preset.name = model.to_string();
preset.set_option(*this, "LLAMA_ARG_HF_REPO", model.to_string());
out[preset.name] = preset; out[preset.name] = preset;
} }
return out; return out;
} }
struct local_model {
std::string name;
std::string path;
std::string path_mmproj;
};
common_presets common_preset_context::load_from_models_dir(const std::string & models_dir) const {
if (!std::filesystem::exists(models_dir) || !std::filesystem::is_directory(models_dir)) {
throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", models_dir.c_str()));
}
std::vector<local_model> models;
auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
auto files = fs_list(subdir_path, false);
common_file_info model_file;
common_file_info first_shard_file;
common_file_info mmproj_file;
for (const auto & file : files) {
if (string_ends_with(file.name, ".gguf")) {
if (file.name.find("mmproj") != std::string::npos) {
mmproj_file = file;
} else if (file.name.find("-00001-of-") != std::string::npos) {
first_shard_file = file;
} else {
model_file = file;
}
}
}
// single file model
local_model model{
/* name */ name,
/* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
/* path_mmproj */ mmproj_file.path // can be empty
};
if (!model.path.empty()) {
models.push_back(model);
}
};
auto files = fs_list(models_dir, true);
for (const auto & file : files) {
if (file.is_dir) {
scan_subdir(file.path, file.name);
} else if (string_ends_with(file.name, ".gguf")) {
// single file model
std::string name = file.name;
string_replace_all(name, ".gguf", "");
local_model model{
/* name */ name,
/* path */ file.path,
/* path_mmproj */ ""
};
models.push_back(model);
}
}
// convert local models to presets
common_presets out;
for (const auto & model : models) {
common_preset preset;
preset.name = model.name;
preset.set_option(*this, "LLAMA_ARG_MODEL", model.path);
if (!model.path_mmproj.empty()) {
preset.set_option(*this, "LLAMA_ARG_MMPROJ", model.path_mmproj);
}
out[preset.name] = preset;
}
return out;
}
common_preset common_preset_context::load_from_args(int argc, char ** argv) const {
common_preset preset;
preset.name = COMMON_PRESET_DEFAULT_NAME;
bool ok = common_params_to_map(argc, argv, ctx_params.ex, preset.options);
if (!ok) {
throw std::runtime_error("failed to parse CLI arguments into preset");
}
return preset;
}
common_presets common_preset_context::cascade(const common_presets & base, const common_presets & added) const {
common_presets out = base; // copy
for (const auto & [name, preset_added] : added) {
if (out.find(name) != out.end()) {
// if exists, merge
common_preset & target = out[name];
target.merge(preset_added);
} else {
// otherwise, add directly
out[name] = preset_added;
}
}
return out;
}
common_presets common_preset_context::cascade(const common_preset & base, const common_presets & presets) const {
common_presets out;
for (const auto & [name, preset] : presets) {
common_preset tmp = base; // copy
tmp.name = name;
tmp.merge(preset);
out[name] = std::move(tmp);
}
return out;
}

View File

@ -13,20 +13,62 @@
constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default"; constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default";
struct common_preset_context;
struct common_preset { struct common_preset {
std::string name; std::string name;
// TODO: support repeated args in the future
// options are stored as common_arg to string mapping, representing CLI arg and its value
std::map<common_arg, std::string> options; std::map<common_arg, std::string> options;
// convert preset to CLI argument list // convert preset to CLI argument list
std::vector<std::string> to_args() const; std::vector<std::string> to_args(const std::string & bin_path = "") const;
// convert preset to INI format string // convert preset to INI format string
std::string to_ini() const; std::string to_ini() const;
// TODO: maybe implement to_env() if needed // TODO: maybe implement to_env() if needed
// modify preset options where argument is identified by its env variable
void set_option(const common_preset_context & ctx, const std::string & env, const std::string & value);
// unset option by its env variable
void unset_option(const std::string & env);
// get option value by its env variable, return false if not found
bool get_option(const std::string & env, std::string & value) const;
// merge another preset into this one, overwriting existing options
void merge(const common_preset & other);
}; };
// interface for multiple presets in one file // interface for multiple presets in one file
using common_presets = std::map<std::string, common_preset>; using common_presets = std::map<std::string, common_preset>;
common_presets common_presets_load(const std::string & path, common_params_context & ctx_params);
// context for loading and editing presets
struct common_preset_context {
common_params default_params; // unused for now
common_params_context ctx_params;
std::map<std::string, common_arg> key_to_opt;
common_preset_context(llama_example ex);
// load presets from INI file
common_presets load_from_ini(const std::string & path, common_preset & global) const;
// generate presets from cached models
common_presets load_from_cache() const;
// generate presets from local models directory
// for the directory structure, see "Using multiple models" in server/README.md
common_presets load_from_models_dir(const std::string & models_dir) const;
// generate one preset from CLI arguments
common_preset load_from_args(int argc, char ** argv) const;
// cascade multiple presets if exist on both: base < added
// if preset does not exist in base, it will be added without modification
common_presets cascade(const common_presets & base, const common_presets & added) const;
// apply presets over a base preset (same idea as CSS cascading)
common_presets cascade(const common_preset & base, const common_presets & presets) const;
};

View File

@ -712,6 +712,9 @@ class ModelBase:
if "thinker_config" in config: if "thinker_config" in config:
# rename for Qwen2.5-Omni # rename for Qwen2.5-Omni
config["text_config"] = config["thinker_config"]["text_config"] config["text_config"] = config["thinker_config"]["text_config"]
if "lfm" in config:
# rename for LFM2-Audio
config["text_config"] = config["lfm"]
return config return config
@classmethod @classmethod
@ -9713,12 +9716,12 @@ class LFM2Model(TextModel):
self._add_feed_forward_length() self._add_feed_forward_length()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name if self._is_vision_tensor(name) or self._is_audio_tensor(name):
if is_vision_tensor: # skip multimodal tensors
# skip vision tensors
return [] return []
name = name.replace("language_model.", "") name = name.replace("language_model.", "") # vision
name = name.replace("lfm.", "model.") # audio
# conv op requires 2d tensor # conv op requires 2d tensor
if 'conv.conv' in name: if 'conv.conv' in name:
@ -9726,6 +9729,12 @@ class LFM2Model(TextModel):
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
def _is_vision_tensor(self, name: str) -> bool:
return "vision_tower" in name or "multi_modal_projector" in name
def _is_audio_tensor(self, name: str):
return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])
@ModelBase.register("Lfm2MoeForCausalLM") @ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel): class LFM2MoeModel(TextModel):
@ -9831,6 +9840,81 @@ class LFM2VLModel(MmprojModel):
return [] # skip other tensors return [] # skip other tensors
@ModelBase.register("Lfm2AudioForConditionalGeneration")
class LFM2AudioModel(MmprojModel):
has_vision_encoder = False
has_audio_encoder = True
model_name = "Lfm2AudioEncoder"
_batch_norm_tensors: list[dict[str, Tensor]] | None = None
def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config.get("encoder")
def set_gguf_parameters(self):
assert self.hparams_audio is not None
self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"]
self.hparams_audio["intermediate_size"] = self.hparams_audio["d_model"]
self.hparams_audio["num_attention_heads"] = self.hparams_audio["n_heads"]
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
self.gguf_writer.add_audio_attention_layernorm_eps(1e-5)
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".conv" in name and ".weight" in name:
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip language model tensors
if name.startswith("lfm."):
return []
# for training only
if any(p in name for p in ["audio_loss_weight"]):
return []
# for audio output
if any(p in name for p in ["codebook_offsets", "depth_embeddings", "depth_linear", "depthformer"]):
return []
# fold running_mean, running_var and eps into weight and bias for batch_norm
if "batch_norm" in name:
if self._batch_norm_tensors is None:
self._batch_norm_tensors = [{} for _ in range(self.block_count)]
assert bid is not None
self._batch_norm_tensors[bid][name] = data_torch
if len(self._batch_norm_tensors[bid]) < 5:
return []
weight = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.weight"]
bias = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.bias"]
running_mean = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_mean"]
running_var = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_var"]
eps = 1e-5 # default value
a = weight / torch.sqrt(running_var + eps)
b = bias - running_mean * a
return [
(self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.weight"), a),
(self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.bias"), b),
]
# reshape conv weights
if name.startswith("conformer.pre_encode.conv.") and name.endswith(".bias"):
data_torch = data_torch[:, None, None]
if "conv.depthwise_conv" in name and name.endswith(".weight"):
assert data_torch.shape[1] == 1
data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[2])
if "conv.pointwise_conv" in name and name.endswith(".weight"):
assert data_torch.shape[2] == 1
data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1])
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("SmallThinkerForCausalLM") @ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel): class SmallThinkerModel(TextModel):
model_arch = gguf.MODEL_ARCH.SMALLTHINKER model_arch = gguf.MODEL_ARCH.SMALLTHINKER

View File

@ -1,27 +1,27 @@
# Android # Android
## Build with Android Studio ## Build GUI binding using Android Studio
Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project. Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project.
![Project imported into Android Studio](./android/imported-into-android-studio.png) ![Project imported into Android Studio](./android/imported-into-android-studio.jpg)
This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices. This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices.
It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration. It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration.
A minimal Android app frontend is included to showcase the bindings core functionalities: A minimal Android app frontend is included to showcase the bindings core functionalities:
1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` or a local `File`. 1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` from shared storage, or a local `File` from your app's private storage.
2. **Obtain a `TierDetection` or `InferenceEngine`** instance through the high-level facade APIs. 2. **Obtain a `InferenceEngine`** instance through the `AiChat` facade and load your selected model via its app-private file path.
3. **Send a raw user prompt** for automatic template formatting, prefill, and decoding. Then collect the generated tokens in a Kotlin `Flow`. 3. **Send a raw user prompt** for automatic template formatting, prefill, and batch decoding. Then collect the generated tokens in a Kotlin `Flow`.
For a production-ready experience that leverages advanced features such as system prompts and benchmarks, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play. For a production-ready experience that leverages advanced features such as system prompts and benchmarks, plus friendly UI features such as model management and Arm feature visualizer, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play.
This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups: This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups:
| ![Home screen](./android/arm-ai-chat-home-screen.png) | ![System prompt](./android/system-prompt-setup.png) | !["Haiku"](./android/chat-with-system-prompt-haiku.png) | | ![Home screen](https://naco-siren.github.io/ai-chat/policy/index/1-llm-starter-pack.png) | ![System prompt](https://naco-siren.github.io/ai-chat/policy/index/5-system-prompt.png) | !["Haiku"](https://naco-siren.github.io/ai-chat/policy/index/4-metrics.png) |
|:------------------------------------------------------:|:----------------------------------------------------:|:--------------------------------------------------------:| |:------------------------------------------------------:|:----------------------------------------------------:|:--------------------------------------------------------:|
| Home screen | System prompt | "Haiku" | | Home screen | System prompt | "Haiku" |
## Build on Android using Termux ## Build CLI on Android using Termux
[Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid. [Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid.
@ -52,7 +52,7 @@ To see what it might look like visually, here's an old demo of an interactive se
https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4 https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4
## Cross-compile using Android NDK ## Cross-compile CLI using Android NDK
It's possible to build `llama.cpp` for Android on your host system via CMake and the Android NDK. If you are interested in this path, ensure you already have an environment prepared to cross-compile programs for Android (i.e., install the Android SDK). Note that, unlike desktop environments, the Android environment ships with a limited set of native libraries, and so only those libraries are available to CMake when building with the Android NDK (see: https://developer.android.com/ndk/guides/stable_apis.) It's possible to build `llama.cpp` for Android on your host system via CMake and the Android NDK. If you are interested in this path, ensure you already have an environment prepared to cross-compile programs for Android (i.e., install the Android SDK). Note that, unlike desktop environments, the Android environment ships with a limited set of native libraries, and so only those libraries are available to CMake when building with the Android NDK (see: https://developer.android.com/ndk/guides/stable_apis.)
Once you're ready and have cloned `llama.cpp`, invoke the following in the project directory: Once you're ready and have cloned `llama.cpp`, invoke the following in the project directory:

Binary file not shown.

After

Width:  |  Height:  |  Size: 479 KiB

View File

@ -22,6 +22,7 @@
"GGML_LLAMAFILE": "OFF", "GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON", "GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON", "GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_CURL": "OFF" "LLAMA_CURL": "OFF"
} }
}, },
@ -36,6 +37,7 @@
"GGML_LLAMAFILE": "OFF", "GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON", "GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON", "GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_CURL": "OFF" "LLAMA_CURL": "OFF"
} }
}, },

View File

@ -1,55 +1,57 @@
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android" <androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto" xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools" xmlns:tools="http://schemas.android.com/tools"
android:id="@+id/main" android:id="@+id/main"
android:layout_height="match_parent" android:layout_height="match_parent"
android:layout_width="match_parent"> android:layout_width="match_parent">
<LinearLayout <LinearLayout
android:fitsSystemWindows="true" android:fitsSystemWindows="true"
android:layout_width="match_parent" android:layout_width="match_parent"
android:layout_height="match_parent" android:layout_height="match_parent"
android:orientation="vertical" android:orientation="vertical"
android:layout_marginEnd="4dp"
tools:context=".MainActivity"> tools:context=".MainActivity">
<FrameLayout <ScrollView
android:layout_width="match_parent" android:layout_width="match_parent"
android:layout_height="0dp" android:layout_height="0dp"
android:layout_weight="1"> android:layout_weight="1"
android:fadeScrollbars="false">
<ScrollView <TextView
android:id="@+id/gguf"
android:layout_width="match_parent" android:layout_width="match_parent"
android:layout_height="wrap_content" android:layout_height="wrap_content"
android:fadeScrollbars="false"> android:layout_margin="16dp"
android:text="Selected GGUF model's metadata will show here."
style="@style/TextAppearance.MaterialComponents.Body2" />
<TextView </ScrollView>
android:id="@+id/gguf"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_margin="16dp"
android:text="Selected GGUF model's metadata will show here."
style="@style/TextAppearance.MaterialComponents.Body2"
android:maxLines="100" />
</ScrollView> <com.google.android.material.divider.MaterialDivider
android:layout_width="match_parent"
</FrameLayout> android:layout_height="2dp"
android:layout_marginHorizontal="16dp"
android:layout_marginVertical="8dp" />
<androidx.recyclerview.widget.RecyclerView <androidx.recyclerview.widget.RecyclerView
android:id="@+id/messages" android:id="@+id/messages"
android:layout_width="match_parent" android:layout_width="match_parent"
android:layout_height="0dp" android:layout_height="0dp"
android:layout_weight="4" android:layout_weight="4"
android:padding="16dp"
android:fadeScrollbars="false" android:fadeScrollbars="false"
android:scrollbars="vertical"
app:reverseLayout="true" app:reverseLayout="true"
tools:listitem="@layout/item_message_assistant"/> tools:listitem="@layout/item_message_assistant"/>
<LinearLayout <LinearLayout
android:layout_width="match_parent" android:layout_width="match_parent"
android:layout_height="wrap_content" android:layout_height="wrap_content"
android:orientation="horizontal"> android:orientation="horizontal"
android:paddingStart="16dp"
android:paddingEnd="4dp">
<EditText <EditText
android:id="@+id/user_input" android:id="@+id/user_input"
@ -67,7 +69,7 @@
style="@style/Widget.Material3.FloatingActionButton.Primary" style="@style/Widget.Material3.FloatingActionButton.Primary"
android:layout_width="wrap_content" android:layout_width="wrap_content"
android:layout_height="wrap_content" android:layout_height="wrap_content"
android:layout_margin="8dp" android:layout_margin="12dp"
android:src="@drawable/outline_folder_open_24" /> android:src="@drawable/outline_folder_open_24" />
</LinearLayout> </LinearLayout>

View File

@ -2,7 +2,8 @@
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent" android:layout_width="match_parent"
android:layout_height="wrap_content" android:layout_height="wrap_content"
android:padding="8dp" android:layout_marginHorizontal="16dp"
android:layout_marginVertical="8dp"
android:gravity="start"> android:gravity="start">
<TextView <TextView

View File

@ -2,7 +2,8 @@
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent" android:layout_width="match_parent"
android:layout_height="wrap_content" android:layout_height="wrap_content"
android:padding="8dp" android:layout_marginHorizontal="16dp"
android:layout_marginVertical="8dp"
android:gravity="end"> android:gravity="end">
<TextView <TextView

View File

@ -2,135 +2,22 @@
import argparse import argparse
import os import os
import sys
import importlib import importlib
from pathlib import Path from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
import torch import torch
import numpy as np import numpy as np
from utils.common import debug_hook
### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
# # log inputs
# summarize(q, "RoPE.q_in")
# summarize(k, "RoPE.k_in")
# # call original
# q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
# # log outputs
# summarize(q_out, "RoPE.q_out")
# summarize(k_out, "RoPE.k_out")
# return q_out, k_out
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### == END ROPE DEBUG ===
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
"""
Print a tensor in llama.cpp debug style.
Supports:
- 2D tensors (seq, hidden)
- 3D tensors (batch, seq, hidden)
- 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
Shows first and last max_vals of each vector per sequence position.
"""
t = tensor.detach().to(torch.float32).cpu()
# Determine dimensions
if t.ndim == 3:
_, s, _ = t.shape
elif t.ndim == 2:
_, s = 1, t.shape[0]
t = t.unsqueeze(0)
elif t.ndim == 4:
_, s, _, _ = t.shape
else:
print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
return
ten_shape = t.shape
print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
print(" [")
print(" [")
# Determine indices for first and last sequences
first_indices = list(range(min(s, max_seq)))
last_indices = list(range(max(0, s - max_seq), s))
# Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
# Combine indices
if has_overlap:
# If there's overlap, just use the combined unique indices
indices = sorted(list(set(first_indices + last_indices)))
separator_index = None
else:
# If no overlap, we'll add a separator between first and last sequences
indices = first_indices + last_indices
separator_index = len(first_indices)
for i, si in enumerate(indices):
# Add separator if needed
if separator_index is not None and i == separator_index:
print(" ...")
# Extract appropriate slice
vec = t[0, si]
if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
flat = vec.flatten().tolist()
else: # 2D or 3D case
flat = vec.tolist()
# First and last slices
first = flat[:max_vals]
last = flat[-max_vals:] if len(flat) >= max_vals else flat
first_str = ", ".join(f"{v:12.4f}" for v in first)
last_str = ", ".join(f"{v:12.4f}" for v in last)
print(f" [{first_str}, ..., {last_str}]")
print(" ],")
print(" ]")
print(f" sum = {t.sum().item():.6f}\n")
def debug_hook(name):
def fn(_m, input, output):
if isinstance(input, torch.Tensor):
summarize(input, name + "_in")
elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
summarize(input[0], name + "_in")
if isinstance(output, torch.Tensor):
summarize(output, name + "_out")
elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
summarize(output[0], name + "_out")
return fn
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
parser = argparse.ArgumentParser(description="Process model with specified path") parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model") parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False) parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
args = parser.parse_args() args = parser.parse_args()
model_path = os.environ.get("MODEL_PATH", args.model_path) model_path = os.environ.get("MODEL_PATH", args.model_path)
@ -139,6 +26,12 @@ if model_path is None:
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable" "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
) )
### If you want to dump RoPE activations, uncomment the following lines:
### === START ROPE DEBUG ===
# from utils.common import setup_rope_debug
# setup_rope_debug("transformers.models.apertus.modeling_apertus")
### == END ROPE DEBUG ===
print("Loading model and tokenizer using AutoTokenizer:", model_path) print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@ -156,6 +49,7 @@ print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id) print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id) print("EOS token id: ", config.eos_token_id)
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
if unreleased_model_name: if unreleased_model_name:
model_name_lower = unreleased_model_name.lower() model_name_lower = unreleased_model_name.lower()
unreleased_module_path = ( unreleased_module_path = (
@ -184,9 +78,10 @@ else:
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
) )
for name, module in model.named_modules(): if args.verbose:
if len(list(module.children())) == 0: # only leaf modules for name, module in model.named_modules():
module.register_forward_hook(debug_hook(name)) if len(list(module.children())) == 0: # only leaf modules
module.register_forward_hook(debug_hook(name))
model_name = os.path.basename(model_path) model_name = os.path.basename(model_path)
# Printing the Model class to allow for easier debugging. This can be useful # Printing the Model class to allow for easier debugging. This can be useful

View File

@ -2,6 +2,8 @@
import os import os
import sys import sys
import torch
def get_model_name_from_env_path(env_path_name): def get_model_name_from_env_path(env_path_name):
model_path = os.getenv(env_path_name) model_path = os.getenv(env_path_name)
@ -18,3 +20,131 @@ def get_model_name_from_env_path(env_path_name):
name = name[:-5] name = name[:-5]
return name return name
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
"""
Print a tensor in llama.cpp debug style.
Supports:
- 2D tensors (seq, hidden)
- 3D tensors (batch, seq, hidden)
- 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
Shows first and last max_vals of each vector per sequence position.
"""
t = tensor.detach().to(torch.float32).cpu()
# Determine dimensions
if t.ndim == 3:
_, s, _ = t.shape
elif t.ndim == 2:
_, s = 1, t.shape[0]
t = t.unsqueeze(0)
elif t.ndim == 4:
_, s, _, _ = t.shape
else:
print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
return
ten_shape = t.shape
print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
print(" [")
print(" [")
# Determine indices for first and last sequences
first_indices = list(range(min(s, max_seq)))
last_indices = list(range(max(0, s - max_seq), s))
# Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
# Combine indices
if has_overlap:
# If there's overlap, just use the combined unique indices
indices = sorted(list(set(first_indices + last_indices)))
separator_index = None
else:
# If no overlap, we'll add a separator between first and last sequences
indices = first_indices + last_indices
separator_index = len(first_indices)
for i, si in enumerate(indices):
# Add separator if needed
if separator_index is not None and i == separator_index:
print(" ...")
# Extract appropriate slice
vec = t[0, si]
if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
flat = vec.flatten().tolist()
else: # 2D or 3D case
flat = vec.tolist()
# First and last slices
first = flat[:max_vals]
last = flat[-max_vals:] if len(flat) >= max_vals else flat
first_str = ", ".join(f"{v:12.4f}" for v in first)
last_str = ", ".join(f"{v:12.4f}" for v in last)
print(f" [{first_str}, ..., {last_str}]")
print(" ],")
print(" ]")
print(f" sum = {t.sum().item():.6f}\n")
def debug_hook(name):
def fn(_m, input, output):
if isinstance(input, torch.Tensor):
summarize(input, name + "_in")
elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
summarize(input[0], name + "_in")
if isinstance(output, torch.Tensor):
summarize(output, name + "_out")
elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
summarize(output[0], name + "_out")
return fn
def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
"""
Apply monkey patch to dump RoPE activations for debugging.
Args:
model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus")
function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb")
Example:
from utils.common import setup_rope_debug
setup_rope_debug("transformers.models.apertus.modeling_apertus")
"""
import importlib
# Import the module and get the original function
module = importlib.import_module(model_module_path)
orig_rope = getattr(module, function_name)
# Set torch print options for better debugging
torch.set_printoptions(threshold=float('inf'))
torch.set_printoptions(precision=6, sci_mode=False)
def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
# log inputs
summarize(q, "RoPE.q_in")
summarize(k, "RoPE.k_in")
# call original
q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
# log outputs
summarize(q_out, "RoPE.q_out")
summarize(k_out, "RoPE.k_out")
return q_out, k_out
# Patch it
setattr(module, function_name, debug_rope)
print(f"RoPE debug patching applied to {model_module_path}.{function_name}")

View File

@ -254,6 +254,7 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
"gmml: OpenCL API version to target") "gmml: OpenCL API version to target")
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF) option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
# toolchain for vulkan-shaders-gen # toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")

View File

@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
const int threads = 128; const int threads = 128;
GGML_ASSERT(nr % threads == 0); GGML_ASSERT(nr % threads == 0);
if (n_t <= 32) { auto launch_kernel = [&](auto NC) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); constexpr int kNC = decltype(NC)::value;
if (nc == 4) { if (n_t <= 32) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
dst, dst_nb0, dst_nb1, dst_nb2, n_t); ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
} else if (nc == 3) { dst, dst_nb0, dst_nb1, dst_nb2, n_t);
ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else { } else {
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
} else {
if (nc == 4) {
const int64_t split_n_t = 32; const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>( ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else if (nc == 3) {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
} }
};
switch (nc) {
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
} }
} }

View File

@ -2,6 +2,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject) include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
add_library(htp_iface OBJECT add_library(htp_iface OBJECT
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c) ${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
@ -41,7 +42,8 @@ set(HTP_CMAKE_ARGS
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT} -DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT} -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}) -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
ExternalProject_Add(htp-v68 ExternalProject_Add(htp-v68
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON

View File

@ -31,7 +31,8 @@ add_library(${HTP_LIB} SHARED
) )
target_compile_definitions(${HTP_LIB} PRIVATE target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>) $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
build_idl(htp_iface.idl ${HTP_LIB}) build_idl(htp_iface.idl ${HTP_LIB})

View File

@ -92,6 +92,18 @@ static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
}; };
// vdelta control to replicate first fp16 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements // vdelta control to expand first 32 e8m0 values into 32 uint32 elements
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
@ -1594,6 +1606,118 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
// *** dynamic quant // *** dynamic quant
static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0);
HVX_Vector * vx = (HVX_Vector *) x;
HVX_Vector zero = Q6_V_vsplat_R(0);
// Use reduce max fp32 to find max(abs(e)) first
HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0]));
HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1]));
HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2]));
HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3]));
// Load and convert into QF32
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert to QF32
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
// Combine and convert to fp16
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
// Convert into fp16
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
hvx_vec_store_u(y_d + 0, 2, vd01_hf);
HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);
hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);
hvx_vec_store_u(y_d + 4, 2, vd23_hf);
rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);
hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
// Divide input by the scale
HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
// Convert to int8
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
*(HVX_Vector *) y_q = vx_i8;
}
static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0);
HVX_Vector * vx = (HVX_Vector *) x;
// Load and convert into QF32
HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert into fp16
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Compute max and scale
HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
hvx_vec_store_u(y_d + 0, 4, vd01_hf);
hvx_vec_store_u(y_d + 4, 4, vd23_hf);
// Divide input by the scale
HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
// Convert to int8
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
*(HVX_Vector *) y_q = vx_i8;
}
static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0); assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0); assert((unsigned long) y_q % 128 == 0);
@ -1655,10 +1779,24 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u
uint8_t * restrict t_d = (uint8_t *) x; uint8_t * restrict t_d = (uint8_t *) x;
for (uint32_t i = 0; i < nb; i++) { for (uint32_t i = 0; i < nb; i++) {
#if FP32_QUANTIZE_GROUP_SIZE == 32
quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
t_d + (i * 2 + 0) * dblk_size / 2);
quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
t_d + (i * 2 + 1) * dblk_size / 2);
#elif FP32_QUANTIZE_GROUP_SIZE == 64
quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
t_d + (i * 2 + 0) * dblk_size / 2);
quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
t_d + (i * 2 + 1) * dblk_size / 2);
#elif FP32_QUANTIZE_GROUP_SIZE == 128
quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
t_d + (i * 2 + 0) * dblk_size / 2); t_d + (i * 2 + 0) * dblk_size / 2);
quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
t_d + (i * 2 + 1) * dblk_size / 2); t_d + (i * 2 + 1) * dblk_size / 2);
#else
#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
#endif
} }
// now copy the scales into final location // now copy the scales into final location
@ -1671,6 +1809,7 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
uint32_t nth, uint32_t nth,
uint32_t ith, uint32_t ith,
uint32_t nrows_per_thread) { uint32_t nrows_per_thread) {
uint64_t t1 = HAP_perf_get_qtimer_count(); uint64_t t1 = HAP_perf_get_qtimer_count();
const uint32_t ne0 = src->ne[0]; const uint32_t ne0 = src->ne[0];

View File

@ -1527,6 +1527,8 @@ private:
#endif // GGML_VULKAN_MEMORY_DEBUG #endif // GGML_VULKAN_MEMORY_DEBUG
static bool vk_perf_logger_enabled = false; static bool vk_perf_logger_enabled = false;
static bool vk_perf_logger_concurrent = false;
static bool vk_enable_sync_logger = false;
// number of calls between perf logger prints // number of calls between perf logger prints
static uint32_t vk_perf_logger_frequency = 1; static uint32_t vk_perf_logger_frequency = 1;
@ -1577,14 +1579,14 @@ class vk_perf_logger {
flops.clear(); flops.clear();
} }
void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) { std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
*n_flops = 0;
std::string fusion_str; std::string fusion_str;
if (fusion_name) { if (fusion_name) {
fusion_str = fusion_name + std::string(" "); fusion_str = fusion_name + std::string(" ");
} }
if (node->op == GGML_OP_UNARY) { if (node->op == GGML_OP_UNARY) {
timings[fusion_str + ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node));
return;
} }
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->ne[0]; const uint64_t m = node->ne[0];
@ -1606,9 +1608,8 @@ class vk_perf_logger {
name += " batch=" + std::to_string(batch); name += " batch=" + std::to_string(batch);
} }
name = fusion_str + name; name = fusion_str + name;
timings[name].push_back(time); *n_flops = m * n * (k + (k - 1)) * batch;
flops[name].push_back(m * n * (k + (k - 1)) * batch); return name;
return;
} }
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) { if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
std::string name = ggml_op_name(node->op); std::string name = ggml_op_name(node->op);
@ -1624,20 +1625,17 @@ class vk_perf_logger {
uint64_t size_M = Cout; uint64_t size_M = Cout;
uint64_t size_K = Cin * KW * KH; uint64_t size_K = Cin * KW * KH;
uint64_t size_N = N * OW * OH; uint64_t size_N = N * OW * OH;
uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); *n_flops = size_M * size_N * (size_K + (size_K - 1));
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
", N=N*OW*OH=" + std::to_string(size_N); ", N=N*OW*OH=" + std::to_string(size_N);
name = fusion_str + name; name = fusion_str + name;
flops[name].push_back(n_flops); return name;
timings[name].push_back(time);
return;
} }
if (node->op == GGML_OP_RMS_NORM) { if (node->op == GGML_OP_RMS_NORM) {
std::string name = ggml_op_name(node->op); std::string name = ggml_op_name(node->op);
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")"; name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
name = fusion_str + name; name = fusion_str + name;
timings[name].push_back(time); return name;
return;
} }
if (node->op == GGML_OP_FLASH_ATTN_EXT) { if (node->op == GGML_OP_FLASH_ATTN_EXT) {
const ggml_tensor * dst = node; const ggml_tensor * dst = node;
@ -1653,8 +1651,7 @@ class vk_perf_logger {
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
timings[name.str()].push_back(time); return name.str();
return;
} }
if (node->op == GGML_OP_TOP_K) { if (node->op == GGML_OP_TOP_K) {
std::stringstream name; std::stringstream name;
@ -1662,11 +1659,38 @@ class vk_perf_logger {
name << ggml_op_name(node->op) << name << ggml_op_name(node->op) <<
" K=" << node->ne[0] << " K=" << node->ne[0] <<
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")"; " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
timings[name.str()].push_back(time); return name.str();
return;
} }
timings[fusion_str + ggml_op_name(node->op)].push_back(time); return fusion_str + ggml_op_name(node->op);
} }
void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
uint64_t n_flops;
std::string name = get_node_fusion_name(node, fusion_name, &n_flops);
if (n_flops) {
flops[name].push_back(n_flops);
}
timings[name].push_back(time);
}
void log_timing(const std::vector<ggml_tensor *> &nodes, const std::vector<const char *> &names, uint64_t time) {
uint64_t total_flops = 0;
std::string name;
for (size_t n = 0; n < nodes.size(); ++n) {
uint64_t n_flops = 0;
name += get_node_fusion_name(nodes[n], names[n], &n_flops);
total_flops += n_flops;
if (n != nodes.size() - 1) {
name += ", ";
}
}
if (total_flops) {
flops[name].push_back(total_flops);
}
timings[name].push_back(time);
}
private: private:
std::map<std::string, std::vector<uint64_t>> timings; std::map<std::string, std::vector<uint64_t>> timings;
std::map<std::string, std::vector<uint64_t>> flops; std::map<std::string, std::vector<uint64_t>> flops;
@ -1729,7 +1753,9 @@ struct ggml_backend_vk_context {
std::unique_ptr<vk_perf_logger> perf_logger; std::unique_ptr<vk_perf_logger> perf_logger;
vk::QueryPool query_pool; vk::QueryPool query_pool;
std::vector<const char *> query_fusion_names; std::vector<const char *> query_fusion_names;
std::vector<int> query_fusion_node_count;
std::vector<ggml_tensor *> query_nodes; std::vector<ggml_tensor *> query_nodes;
std::vector<int> query_node_idx;
int32_t num_queries {}; int32_t num_queries {};
int32_t query_idx {}; int32_t query_idx {};
}; };
@ -5194,6 +5220,8 @@ static void ggml_vk_instance_init() {
} }
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY"); const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) { if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
@ -11820,15 +11848,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
} }
} }
#define ENABLE_SYNC_LOGGING 0
if (need_sync) { if (need_sync) {
#if ENABLE_SYNC_LOGGING if (vk_enable_sync_logger) {
std::cerr << "sync" << std::endl; std::cerr << "sync" << std::endl;
#endif }
ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear(); ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx); ggml_vk_sync_buffers(ctx, compute_ctx);
if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
ctx->query_node_idx[ctx->query_idx] = node_idx;
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
}
} }
// Add all fused nodes to the unsynchronized lists. // Add all fused nodes to the unsynchronized lists.
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
@ -11845,20 +11876,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
} }
} }
} }
#if ENABLE_SYNC_LOGGING if (vk_enable_sync_logger) {
for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
auto *n = cgraph->nodes[node_idx + i]; auto *n = cgraph->nodes[node_idx + i];
std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name; std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
if (n->op == GGML_OP_GLU) { if (n->op == GGML_OP_GLU) {
std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " "; std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
}
if (n->op == GGML_OP_ROPE) {
const int mode = ((const int32_t *) n->op_params)[2];
std::cerr << " rope mode: " << mode;
}
std::cerr << std::endl;
} }
if (n->op == GGML_OP_ROPE) {
const int mode = ((const int32_t *) n->op_params)[2];
std::cerr << " rope mode: " << mode;
}
std::cerr << std::endl;
} }
#endif
switch (node->op) { switch (node->op) {
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
@ -13138,12 +13169,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->query_pool = ctx->device->device.createQueryPool(query_create_info); ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);
ctx->num_queries = query_create_info.queryCount; ctx->num_queries = query_create_info.queryCount;
ctx->query_fusion_names.resize(ctx->num_queries); ctx->query_fusion_names.resize(ctx->num_queries);
ctx->query_fusion_node_count.resize(ctx->num_queries);
ctx->query_nodes.resize(ctx->num_queries); ctx->query_nodes.resize(ctx->num_queries);
ctx->query_node_idx.resize(ctx->num_queries);
} }
ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1); ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);
std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr); std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);
std::fill(ctx->query_fusion_node_count.begin(), ctx->query_fusion_node_count.end(), 0);
std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr); std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);
std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
GGML_ASSERT(ctx->compute_ctx.expired()); GGML_ASSERT(ctx->compute_ctx.expired());
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
@ -13272,9 +13307,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
} else { } else {
compute_ctx = ctx->compute_ctx.lock(); compute_ctx = ctx->compute_ctx.lock();
} }
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; if (!vk_perf_logger_concurrent) {
ctx->query_fusion_names[ctx->query_idx] = fusion_string; // track a single node/fusion for the current query
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
} else {
// track a fusion string and number of fused ops for the current node_idx
ctx->query_fusion_names[i] = fusion_string;
ctx->query_fusion_node_count[i] = ctx->num_additional_fused_ops;
}
} }
if (enqueued) { if (enqueued) {
@ -13316,12 +13358,32 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// Get the results and pass them to the logger // Get the results and pass them to the logger
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1); std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
for (int i = 1; i < ctx->query_idx; i++) { if (!vk_perf_logger_concurrent) {
auto node = ctx->query_nodes[i]; // Log each op separately
auto name = ctx->query_fusion_names[i]; for (int i = 1; i < ctx->query_idx; i++) {
ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod)); auto node = ctx->query_nodes[i];
auto name = ctx->query_fusion_names[i];
ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
}
} else {
// Log each group of nodes
int prev_node_idx = 0;
for (int i = 1; i < ctx->query_idx; i++) {
auto cur_node_idx = ctx->query_node_idx[i];
std::vector<ggml_tensor *> nodes;
std::vector<const char *> names;
for (int node_idx = prev_node_idx; node_idx < cur_node_idx; ++node_idx) {
if (ggml_op_is_empty(cgraph->nodes[node_idx]->op)) {
continue;
}
nodes.push_back(cgraph->nodes[node_idx]);
names.push_back(ctx->query_fusion_names[node_idx]);
node_idx += ctx->query_fusion_node_count[node_idx];
}
prev_node_idx = cur_node_idx;
ctx->perf_logger->log_timing(nodes, names, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
}
} }
ctx->perf_logger->print_timings(); ctx->perf_logger->print_timings();
} }

View File

@ -690,6 +690,8 @@ class MODEL_TENSOR(IntEnum):
V_TOK_EOI = auto() # cogvlm V_TOK_EOI = auto() # cogvlm
# audio (mtmd) # audio (mtmd)
A_ENC_EMBD_POS = auto() A_ENC_EMBD_POS = auto()
A_ENC_EMBD_NORM = auto()
A_ENC_EMBD_TO_LOGITS = auto()
A_ENC_CONV1D = auto() A_ENC_CONV1D = auto()
A_PRE_NORM = auto() A_PRE_NORM = auto()
A_POST_NORM = auto() A_POST_NORM = auto()
@ -700,8 +702,13 @@ class MODEL_TENSOR(IntEnum):
A_ENC_OUTPUT = auto() A_ENC_OUTPUT = auto()
A_ENC_OUTPUT_NORM = auto() A_ENC_OUTPUT_NORM = auto()
A_ENC_FFN_UP = auto() A_ENC_FFN_UP = auto()
A_ENC_FFN_NORM = auto()
A_ENC_FFN_GATE = auto() A_ENC_FFN_GATE = auto()
A_ENC_FFN_DOWN = auto() A_ENC_FFN_DOWN = auto()
A_ENC_FFN_UP_1 = auto()
A_ENC_FFN_NORM_1 = auto()
A_ENC_FFN_GATE_1 = auto()
A_ENC_FFN_DOWN_1 = auto()
A_MMPROJ = auto() A_MMPROJ = auto()
A_MMPROJ_FC = auto() A_MMPROJ_FC = auto()
A_MM_NORM_PRE = auto() A_MM_NORM_PRE = auto()
@ -713,6 +720,16 @@ class MODEL_TENSOR(IntEnum):
NEXTN_HNORM = auto() NEXTN_HNORM = auto()
NEXTN_SHARED_HEAD_HEAD = auto() NEXTN_SHARED_HEAD_HEAD = auto()
NEXTN_SHARED_HEAD_NORM = auto() NEXTN_SHARED_HEAD_NORM = auto()
# lfm2 audio
A_ENC_NORM_CONV = auto()
A_ENC_LINEAR_POS = auto()
A_ENC_POS_BIAS_U = auto()
A_ENC_POS_BIAS_V = auto()
A_ENC_OUT = auto()
A_ENC_CONV_DW = auto() # SSM conv
A_ENC_CONV_NORM = auto() # SSM conv
A_ENC_CONV_PW1 = auto()
A_ENC_CONV_PW2 = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -1064,7 +1081,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_TOK_BOI: "v.boi", MODEL_TENSOR.V_TOK_BOI: "v.boi",
MODEL_TENSOR.V_TOK_EOI: "v.eoi", MODEL_TENSOR.V_TOK_EOI: "v.eoi",
# audio (mtmd) # audio (mtmd)
# note: all audio tensor names must use prefix "a." or "mm.a."
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm",
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
MODEL_TENSOR.A_POST_NORM: "a.post_ln", MODEL_TENSOR.A_POST_NORM: "a.post_ln",
@ -1074,13 +1094,28 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1", MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out", MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2", MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
MODEL_TENSOR.A_ENC_FFN_NORM: "a.blk.{bid}.ffn_norm",
MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up", MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate", MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down", MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
MODEL_TENSOR.A_ENC_FFN_NORM_1: "a.blk.{bid}.ffn_norm_1",
MODEL_TENSOR.A_ENC_FFN_UP_1: "a.blk.{bid}.ffn_up_1",
MODEL_TENSOR.A_ENC_FFN_GATE_1: "a.blk.{bid}.ffn_gate_1",
MODEL_TENSOR.A_ENC_FFN_DOWN_1: "a.blk.{bid}.ffn_down_1",
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}", MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
# lfm2 audio
MODEL_TENSOR.A_ENC_NORM_CONV: "a.blk.{bid}.norm_conv",
MODEL_TENSOR.A_ENC_LINEAR_POS: "a.blk.{bid}.linear_pos",
MODEL_TENSOR.A_ENC_POS_BIAS_U: "a.blk.{bid}.pos_bias_u",
MODEL_TENSOR.A_ENC_POS_BIAS_V: "a.blk.{bid}.pos_bias_v",
MODEL_TENSOR.A_ENC_OUT: "a.pre_encode.out",
MODEL_TENSOR.A_ENC_CONV_DW: "a.blk.{bid}.conv_dw",
MODEL_TENSOR.A_ENC_CONV_NORM: "a.blk.{bid}.conv_norm",
MODEL_TENSOR.A_ENC_CONV_PW1: "a.blk.{bid}.conv_pw1",
MODEL_TENSOR.A_ENC_CONV_PW2: "a.blk.{bid}.conv_pw2",
# NextN/MTP # NextN/MTP
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj", MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens", MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
@ -1145,6 +1180,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_TOK_EOI, MODEL_TENSOR.V_TOK_EOI,
# audio # audio
MODEL_TENSOR.A_ENC_EMBD_POS, MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_EMBD_NORM,
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
MODEL_TENSOR.A_ENC_CONV1D, MODEL_TENSOR.A_ENC_CONV1D,
MODEL_TENSOR.A_PRE_NORM, MODEL_TENSOR.A_PRE_NORM,
MODEL_TENSOR.A_POST_NORM, MODEL_TENSOR.A_POST_NORM,
@ -1154,13 +1191,27 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.A_ENC_INPUT_NORM, MODEL_TENSOR.A_ENC_INPUT_NORM,
MODEL_TENSOR.A_ENC_OUTPUT, MODEL_TENSOR.A_ENC_OUTPUT,
MODEL_TENSOR.A_ENC_OUTPUT_NORM, MODEL_TENSOR.A_ENC_OUTPUT_NORM,
MODEL_TENSOR.A_ENC_FFN_NORM,
MODEL_TENSOR.A_ENC_FFN_UP, MODEL_TENSOR.A_ENC_FFN_UP,
MODEL_TENSOR.A_ENC_FFN_GATE, MODEL_TENSOR.A_ENC_FFN_GATE,
MODEL_TENSOR.A_ENC_FFN_DOWN, MODEL_TENSOR.A_ENC_FFN_DOWN,
MODEL_TENSOR.A_ENC_FFN_NORM_1,
MODEL_TENSOR.A_ENC_FFN_UP_1,
MODEL_TENSOR.A_ENC_FFN_GATE_1,
MODEL_TENSOR.A_ENC_FFN_DOWN_1,
MODEL_TENSOR.A_MMPROJ, MODEL_TENSOR.A_MMPROJ,
MODEL_TENSOR.A_MMPROJ_FC, MODEL_TENSOR.A_MMPROJ_FC,
MODEL_TENSOR.A_MM_NORM_PRE, MODEL_TENSOR.A_MM_NORM_PRE,
MODEL_TENSOR.A_MM_NORM_MID, MODEL_TENSOR.A_MM_NORM_MID,
MODEL_TENSOR.A_ENC_NORM_CONV,
MODEL_TENSOR.A_ENC_LINEAR_POS,
MODEL_TENSOR.A_ENC_POS_BIAS_U,
MODEL_TENSOR.A_ENC_POS_BIAS_V,
MODEL_TENSOR.A_ENC_OUT,
MODEL_TENSOR.A_ENC_CONV_DW,
MODEL_TENSOR.A_ENC_CONV_NORM,
MODEL_TENSOR.A_ENC_CONV_PW1,
MODEL_TENSOR.A_ENC_CONV_PW2,
], ],
MODEL_ARCH.LLAMA: [ MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
@ -3363,6 +3414,7 @@ class VisionProjectorType:
LIGHTONOCR = "lightonocr" LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm" COGVLM = "cogvlm"
JANUS_PRO = "janus_pro" JANUS_PRO = "janus_pro"
LFM2A = "lfm2a" # audio
GLM4V = "glm4v" GLM4V = "glm4v"

View File

@ -1535,10 +1535,20 @@ class TensorNameMap:
MODEL_TENSOR.A_ENC_EMBD_POS: ( MODEL_TENSOR.A_ENC_EMBD_POS: (
"audio_tower.embed_positions", # ultravox "audio_tower.embed_positions", # ultravox
"audio_embedding.embedding", # lfm2
),
MODEL_TENSOR.A_ENC_EMBD_NORM: (
"audio_embedding.embedding_norm", # lfm2
),
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: (
"audio_embedding.to_logits", # lfm2
), ),
MODEL_TENSOR.A_ENC_CONV1D: ( MODEL_TENSOR.A_ENC_CONV1D: (
"audio_tower.conv{bid}", # ultravox "audio_tower.conv{bid}", # ultravox
"conformer.pre_encode.conv.{bid}", # lfm2
), ),
MODEL_TENSOR.A_PRE_NORM: (), MODEL_TENSOR.A_PRE_NORM: (),
@ -1550,36 +1560,76 @@ class TensorNameMap:
MODEL_TENSOR.A_ENC_ATTN_Q: ( MODEL_TENSOR.A_ENC_ATTN_Q: (
"audio_tower.layers.{bid}.self_attn.q_proj", # ultravox "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
"conformer.layers.{bid}.self_attn.linear_q", # lfm2
), ),
MODEL_TENSOR.A_ENC_ATTN_K: ( MODEL_TENSOR.A_ENC_ATTN_K: (
"audio_tower.layers.{bid}.self_attn.k_proj", # ultravox "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
"conformer.layers.{bid}.self_attn.linear_k", # lfm2
), ),
MODEL_TENSOR.A_ENC_ATTN_V: ( MODEL_TENSOR.A_ENC_ATTN_V: (
"audio_tower.layers.{bid}.self_attn.v_proj", # ultravox "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
"conformer.layers.{bid}.self_attn.linear_v", # lfm2
), ),
MODEL_TENSOR.A_ENC_INPUT_NORM: ( MODEL_TENSOR.A_ENC_INPUT_NORM: (
"audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
"conformer.layers.{bid}.norm_self_att", # lfm2
), ),
MODEL_TENSOR.A_ENC_OUTPUT: ( MODEL_TENSOR.A_ENC_OUTPUT: (
"audio_tower.layers.{bid}.self_attn.out_proj", # ultravox "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
"conformer.layers.{bid}.self_attn.linear_out", # lfm2
), ),
MODEL_TENSOR.A_ENC_OUTPUT_NORM: ( MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
"audio_tower.layers.{bid}.final_layer_norm", # ultravox "audio_tower.layers.{bid}.final_layer_norm", # ultravox
"conformer.layers.{bid}.norm_out", # lfm2
),
MODEL_TENSOR.A_ENC_FFN_NORM: (
"conformer.layers.{bid}.norm_feed_forward1", # lfm2
), ),
MODEL_TENSOR.A_ENC_FFN_UP: ( MODEL_TENSOR.A_ENC_FFN_UP: (
"audio_tower.layers.{bid}.fc1", # ultravox "audio_tower.layers.{bid}.fc1", # ultravox
"conformer.layers.{bid}.feed_forward1.linear1", # lfm2
), ),
MODEL_TENSOR.A_ENC_FFN_GATE: (), MODEL_TENSOR.A_ENC_FFN_GATE: (),
MODEL_TENSOR.A_ENC_FFN_DOWN: ( MODEL_TENSOR.A_ENC_FFN_DOWN: (
"audio_tower.layers.{bid}.fc2", # ultravox "audio_tower.layers.{bid}.fc2", # ultravox
"conformer.layers.{bid}.feed_forward1.linear2", # lfm2
),
MODEL_TENSOR.A_ENC_FFN_UP_1: (
"conformer.layers.{bid}.feed_forward2.linear1", # lfm2
),
MODEL_TENSOR.A_ENC_FFN_DOWN_1: (
"conformer.layers.{bid}.feed_forward2.linear2", # lfm2
),
MODEL_TENSOR.A_ENC_FFN_NORM_1: (
"conformer.layers.{bid}.norm_feed_forward2", # lfm2
),
MODEL_TENSOR.A_ENC_LINEAR_POS: (
"conformer.layers.{bid}.self_attn.linear_pos", # lfm2
),
MODEL_TENSOR.A_ENC_POS_BIAS_U: (
"conformer.layers.{bid}.self_attn.pos_bias_u", # lfm2
),
MODEL_TENSOR.A_ENC_POS_BIAS_V: (
"conformer.layers.{bid}.self_attn.pos_bias_v", # lfm2
),
MODEL_TENSOR.A_ENC_OUT: (
"conformer.pre_encode.out", # lfm2
), ),
# note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors # note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors
@ -1587,6 +1637,7 @@ class TensorNameMap:
MODEL_TENSOR.A_MMPROJ: ( MODEL_TENSOR.A_MMPROJ: (
"audio.multi_modal_projector.linear_{bid}", # ultravox "audio.multi_modal_projector.linear_{bid}", # ultravox
"audio_adapter.model.{bid}" # lfm2
), ),
MODEL_TENSOR.A_MMPROJ_FC: ( MODEL_TENSOR.A_MMPROJ_FC: (
@ -1602,6 +1653,26 @@ class TensorNameMap:
"audio.multi_modal_projector.ln_mid", # ultravox "audio.multi_modal_projector.ln_mid", # ultravox
), ),
MODEL_TENSOR.A_ENC_CONV_DW: (
"conformer.layers.{bid}.conv.depthwise_conv", # lfm2
),
MODEL_TENSOR.A_ENC_CONV_NORM: (
"conformer.layers.{bid}.conv.batch_norm", # lfm2
),
MODEL_TENSOR.A_ENC_CONV_PW1: (
"conformer.layers.{bid}.conv.pointwise_conv1", # lfm2
),
MODEL_TENSOR.A_ENC_CONV_PW2: (
"conformer.layers.{bid}.conv.pointwise_conv2", # lfm2
),
MODEL_TENSOR.A_ENC_NORM_CONV: (
"conformer.layers.{bid}.norm_conv", # lfm2
),
# NextN/MTP tensors for GLM4_MOE # NextN/MTP tensors for GLM4_MOE
MODEL_TENSOR.NEXTN_EH_PROJ: ( MODEL_TENSOR.NEXTN_EH_PROJ: (
"model.layers.{bid}.eh_proj", "model.layers.{bid}.eh_proj",

View File

@ -1086,10 +1086,10 @@ bool llama_model_loader::load_all_data(
} else { } else {
// If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
if (upload_backend) { if (upload_backend) {
auto offset = (off_t) weight->offs; size_t offset = weight->offs;
alignment = file->read_alignment(); alignment = file->read_alignment();
off_t aligned_offset = offset & ~(alignment - 1); size_t aligned_offset = offset & ~(alignment - 1);
off_t offset_from_alignment = offset - aligned_offset; size_t offset_from_alignment = offset - aligned_offset;
file->seek(aligned_offset, SEEK_SET); file->seek(aligned_offset, SEEK_SET);
// Calculate aligned read boundaries // Calculate aligned read boundaries

View File

@ -37,6 +37,30 @@ int main(void) {
exit(1); exit(1);
} }
} }
// ensure shorter argument precedes longer argument
if (opt.args.size() > 1) {
const std::string first(opt.args.front());
const std::string last(opt.args.back());
if (first.length() > last.length()) {
fprintf(stderr, "test-arg-parser: shorter argument should come before longer one: %s, %s\n",
first.c_str(), last.c_str());
assert(false);
}
}
// same check for negated arguments
if (opt.args_neg.size() > 1) {
const std::string first(opt.args_neg.front());
const std::string last(opt.args_neg.back());
if (first.length() > last.length()) {
fprintf(stderr, "test-arg-parser: shorter negated argument should come before longer one: %s, %s\n",
first.c_str(), last.c_str());
assert(false);
}
}
} }
} catch (std::exception & e) { } catch (std::exception & e) {
printf("%s\n", e.what()); printf("%s\n", e.what());

View File

@ -7295,11 +7295,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
for (int64_t d_conv : {3, 4}) { for (int64_t d_conv : {3, 4, 9}) {
for (int64_t d_inner: {1024, 1536, 2048}) { for (int64_t d_inner: {1024, 1536, 2048}) {
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, d_inner, 4, 1}, {d_conv, d_inner, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
} }
} }

View File

@ -15,6 +15,7 @@ add_library(mtmd
clip-graph.h clip-graph.h
models/models.h models/models.h
models/cogvlm.cpp models/cogvlm.cpp
models/conformer.cpp
models/glm4v.cpp models/glm4v.cpp
models/internvl.cpp models/internvl.cpp
models/kimivl.cpp models/kimivl.cpp

View File

@ -138,6 +138,21 @@
#define TN_TOK_BOI "v.boi" #define TN_TOK_BOI "v.boi"
#define TN_TOK_EOI "v.eoi" #define TN_TOK_EOI "v.eoi"
// (conformer) lfm2
#define TN_PRE_ENCODE_OUT "a.pre_encode.out.%s"
#define TN_FFN_NORM "%s.blk.%d.ffn_norm.%s"
#define TN_FFN_NORM_1 "%s.blk.%d.ffn_norm_1.%s"
#define TN_FFN_UP_1 "%s.blk.%d.ffn_up_1.%s"
#define TN_FFN_DOWN_1 "%s.blk.%d.ffn_down_1.%s"
#define TN_POS_BIAS_U "%s.blk.%d.pos_bias_u"
#define TN_POS_BIAS_V "%s.blk.%d.pos_bias_v"
#define TN_NORM_CONV "%s.blk.%d.norm_conv.%s"
#define TN_LINEAR_POS "%s.blk.%d.linear_pos.%s"
#define TN_CONV_DW "%s.blk.%d.conv_dw.%s"
#define TN_CONV_NORM "%s.blk.%d.conv_norm.%s"
#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"
// align x to upper multiple of n // align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@ -170,6 +185,7 @@ enum projector_type {
PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_UNKNOWN, PROJECTOR_TYPE_UNKNOWN,
}; };
@ -198,6 +214,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_GLM4V, "glm4v"},
}; };

View File

@ -4,6 +4,7 @@
#include "clip.h" #include "clip.h"
#include "clip-impl.h" #include "clip-impl.h"
#include <array>
#include <vector> #include <vector>
#include <unordered_set> #include <unordered_set>
#include <cstdint> #include <cstdint>
@ -142,6 +143,30 @@ struct clip_layer {
ggml_tensor * deepstack_fc2_w = nullptr; ggml_tensor * deepstack_fc2_w = nullptr;
ggml_tensor * deepstack_fc2_b = nullptr; ggml_tensor * deepstack_fc2_b = nullptr;
// lfm2
ggml_tensor * ff_norm_w = nullptr;
ggml_tensor * ff_norm_b = nullptr;
ggml_tensor * ff_norm_1_w = nullptr;
ggml_tensor * ff_norm_1_b = nullptr;
ggml_tensor * ff_up_1_w = nullptr;
ggml_tensor * ff_up_1_b = nullptr;
ggml_tensor * ff_down_1_w = nullptr;
ggml_tensor * ff_down_1_b = nullptr;
ggml_tensor * pos_bias_u = nullptr;
ggml_tensor * pos_bias_v = nullptr;
ggml_tensor * norm_conv_w = nullptr;
ggml_tensor * norm_conv_b = nullptr;
ggml_tensor * linear_pos_w = nullptr;
ggml_tensor * conv_norm_w = nullptr;
ggml_tensor * conv_norm_b = nullptr;
ggml_tensor * conv_dw_w = nullptr;
ggml_tensor * conv_dw_b = nullptr;
ggml_tensor * conv_pw1_w = nullptr;
ggml_tensor * conv_pw1_b = nullptr;
ggml_tensor * conv_pw2_w = nullptr;
ggml_tensor * conv_pw2_b = nullptr;
bool has_deepstack() const { bool has_deepstack() const {
return deepstack_fc1_w != nullptr; return deepstack_fc1_w != nullptr;
} }
@ -286,6 +311,12 @@ struct clip_model {
ggml_tensor * mm_boi = nullptr; ggml_tensor * mm_boi = nullptr;
ggml_tensor * mm_eoi = nullptr; ggml_tensor * mm_eoi = nullptr;
// lfm2 audio
std::array<ggml_tensor *, 7> pre_encode_conv_X_w = {nullptr};
std::array<ggml_tensor *, 7> pre_encode_conv_X_b = {nullptr};
ggml_tensor * pre_encode_out_w = nullptr;
ggml_tensor * pre_encode_out_b = nullptr;
bool audio_has_avgpool() const { bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL; || proj_type == PROJECTOR_TYPE_VOXTRAL;

View File

@ -837,6 +837,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{ {
builder = std::make_unique<clip_graph_llava>(ctx, img); builder = std::make_unique<clip_graph_llava>(ctx, img);
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
builder = std::make_unique<clip_graph_conformer>(ctx, img);
} break;
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
{ {
builder = std::make_unique<clip_graph_glm4v>(ctx, img); builder = std::make_unique<clip_graph_glm4v>(ctx, img);
@ -1187,6 +1191,15 @@ struct clip_model_loader {
hparams.audio_window_len = 400; hparams.audio_window_len = 400;
hparams.audio_hop_len = 160; hparams.audio_hop_len = 160;
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
// audio preprocessing params
hparams.audio_chunk_len = 1; // in seconds
hparams.audio_sample_rate = 16000;
hparams.audio_n_fft = 512;
hparams.audio_window_len = 400;
hparams.audio_hop_len = 160;
} break;
default: default:
break; break;
} }
@ -1611,6 +1624,52 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
for (int i : {0, 2, 3, 5, 6}) {
model.pre_encode_conv_X_w[i] = get_tensor(string_format(TN_CONV1D, i, "weight"));
model.pre_encode_conv_X_b[i] = get_tensor(string_format(TN_CONV1D, i, "bias"));
}
model.pre_encode_out_w = get_tensor(string_format(TN_PRE_ENCODE_OUT, "weight"));
model.pre_encode_out_b = get_tensor(string_format(TN_PRE_ENCODE_OUT, "bias"));
model.mm_0_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "weight"));
model.mm_0_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "bias"));
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
model.mm_3_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "weight"));
model.mm_3_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "bias"));
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = model.layers[il];
layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"));
layer.ff_norm_b = get_tensor(string_format(TN_FFN_NORM, prefix, il, "bias"));
layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
layer.ff_norm_1_b = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "bias"));
layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight"));
layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"));
layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"));
layer.pos_bias_u = get_tensor(string_format(TN_POS_BIAS_U, prefix, il));
layer.pos_bias_v = get_tensor(string_format(TN_POS_BIAS_V, prefix, il));
layer.norm_conv_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"));
layer.norm_conv_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"));
layer.linear_pos_w = get_tensor(string_format(TN_LINEAR_POS, prefix, il, "weight"));
layer.conv_norm_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"));
layer.conv_norm_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"));
layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight"));
layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"));
layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight"));
layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"));
layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight"));
layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"));
}
} break;
default: default:
GGML_ASSERT(false && "unknown projector type"); GGML_ASSERT(false && "unknown projector type");
} }
@ -3004,6 +3063,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{ {
n_patches += 2; // for BOI and EOI token embeddings n_patches += 2; // for BOI and EOI token embeddings
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
} break;
default: default:
GGML_ABORT("unsupported projector type"); GGML_ABORT("unsupported projector type");
} }
@ -3362,6 +3425,27 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} }
set_input_i32("pos_w", pos_data); set_input_i32("pos_w", pos_data);
} break; } break;
case PROJECTOR_TYPE_LFM2A:
{
GGML_ASSERT(imgs.entries.size() == 1);
const auto n_frames = clip_n_output_tokens(ctx, imgs.entries.front().get());
auto d_model = 512;
auto seq_len = n_frames * 2 - 1;
std::vector<float> pos_emb(d_model*seq_len);
std::vector<double> inv_freq(d_model / 2);
for (size_t i = 0; i < inv_freq.size(); ++i) {
inv_freq[i] = std::exp(-(std::log(10000.0) / (float)d_model) * (2.0f * (float)(i)));
}
for (int64_t pos = 0; pos < seq_len; ++pos) {
for (size_t i = 0; i < inv_freq.size(); ++i) {
const float ang = (n_frames - pos - 1) * inv_freq[i];
pos_emb[pos*d_model + 2*i + 0] = sinf(ang); // even
pos_emb[pos*d_model + 2*i + 1] = cosf(ang); // odd
}
}
set_input_f32("pos_emb", pos_emb);
} break;
default: default:
GGML_ABORT("Unknown projector type"); GGML_ABORT("Unknown projector type");
} }
@ -3456,6 +3540,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1]; return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM: case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1]; return ctx->model.mm_4h_to_h_w->ne[1];
case PROJECTOR_TYPE_LFM2A:
return ctx->model.position_embeddings->ne[0];
case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_GLM4V:
return ctx->model.mm_ffn_down_w->ne[1]; return ctx->model.mm_ffn_down_w->ne[1];
default: default:

View File

@ -0,0 +1,217 @@
#include "models.h"
ggml_cgraph * clip_graph_conformer::build() {
const int n_frames = img.nx;
const int n_pos = n_frames / 2;
const int n_pos_embd = (((((n_frames + 1) / 2) + 1) / 2 + 1) / 2) * 2 - 1;
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
ggml_tensor * pos_emb = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 512, n_pos_embd);
ggml_set_name(pos_emb, "pos_emb");
ggml_set_input(pos_emb);
ggml_build_forward_expand(gf, pos_emb);
ggml_tensor * inp = build_inp_raw(1);
cb(inp, "input", -1);
auto * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
// pre encode, conv subsampling
{
// layer.0 - conv2d
cur = ggml_conv_2d(ctx0, model.pre_encode_conv_X_w[0], cur, 2, 2, 1, 1, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[0]);
cb(cur, "conformer.pre_encode.conv.{}", 0);
// layer.1 - relu
cur = ggml_relu_inplace(ctx0, cur);
// layer.2 conv2d dw
cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[2], cur, 2, 2, 1, 1, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[2]);
cb(cur, "conformer.pre_encode.conv.{}", 2);
// layer.3 conv2d
cur = ggml_conv_2d_direct(ctx0, model.pre_encode_conv_X_w[3], cur, 1, 1, 0, 0, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[3]);
cb(cur, "conformer.pre_encode.conv.{}", 3);
// layer.4 - relu
cur = ggml_relu_inplace(ctx0, cur);
// layer.5 conv2d dw
cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[5], cur, 2, 2, 1, 1, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[5]);
cb(cur, "conformer.pre_encode.conv.{}", 5);
// layer.6 conv2d
cur = ggml_conv_2d_direct(ctx0, model.pre_encode_conv_X_w[6], cur, 1, 1, 0, 0, 1, 1);
cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[6]);
cb(cur, "conformer.pre_encode.conv.{}", 6);
// layer.7 - relu
cur = ggml_relu_inplace(ctx0, cur);
// flatten channel and frequency axis
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
// calculate out
cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur);
cur = ggml_add(ctx0, cur, model.pre_encode_out_b);
cb(cur, "conformer.pre_encode.out", -1);
}
// pos_emb
cb(pos_emb, "pos_emb", -1);
for (int il = 0; il < hparams.n_layer; il++) {
const auto & layer = model.layers[il];
auto * residual = cur;
cb(cur, "layer.in", il);
// feed_forward1
cur = build_norm(cur, layer.ff_norm_w, layer.ff_norm_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_feed_forward1", il);
cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, layer.ff_down_b, FFN_SILU,
il);
cb(cur, "conformer.layers.{}.feed_forward1.linear2", il);
const auto fc_factor = 0.5f;
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
// self-attention
{
cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_self_att", il);
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]);
ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u);
Q_bias_u = ggml_permute(ctx0, Q_bias_u, 0, 2, 1, 3);
ggml_tensor * Q_bias_v = ggml_add(ctx0, Qcur, layer.pos_bias_v);
Q_bias_v = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3);
// TODO @ngxson : some cont can/should be removed when ggml_mul_mat support these cases
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, Kcur->ne[1]);
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]);
Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3));
// build_attn won't fit due to matrix_ac and matrix_bd separation
ggml_tensor * matrix_ac = ggml_mul_mat(ctx0, Q_bias_u, Kcur);
matrix_ac = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3));
cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il);
auto * p = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb);
cb(p, "conformer.layers.{}.self_attn.linear_pos", il);
p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]);
p = ggml_permute(ctx0, p, 0, 2, 1, 3);
auto * matrix_bd = ggml_mul_mat(ctx0, Q_bias_v, p);
matrix_bd = ggml_cont(ctx0, ggml_permute(ctx0, matrix_bd, 1, 0, 2, 3));
// rel shift
{
const auto pos_len = matrix_bd->ne[0];
const auto q_len = matrix_bd->ne[1];
const auto h = matrix_bd->ne[2];
matrix_bd = ggml_pad(ctx0, matrix_bd, 1, 0, 0, 0);
matrix_bd = ggml_roll(ctx0, matrix_bd, 1, 0, 0, 0);
matrix_bd = ggml_reshape_3d(ctx0, matrix_bd, q_len, pos_len + 1, h);
matrix_bd = ggml_view_3d(ctx0, matrix_bd, q_len, pos_len, h, matrix_bd->nb[1],
matrix_bd->nb[2], matrix_bd->nb[0] * q_len);
matrix_bd = ggml_cont_3d(ctx0, matrix_bd, pos_len, q_len, h);
}
matrix_bd = ggml_view_3d(ctx0, matrix_bd, matrix_ac->ne[0], matrix_bd->ne[1],
matrix_bd->ne[2], matrix_bd->nb[1], matrix_bd->nb[2], 0);
auto * scores = ggml_add(ctx0, matrix_ac, matrix_bd);
scores = ggml_scale(ctx0, scores, 1.0f / std::sqrt(d_head));
cb(scores, "conformer.layers.{}.self_attn.id0", il);
ggml_tensor * attn = ggml_soft_max(ctx0, scores);
ggml_tensor * x = ggml_mul_mat(ctx0, attn, Vcur);
x = ggml_permute(ctx0, x, 2, 0, 1, 3);
x = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]);
ggml_tensor * out = ggml_mul_mat(ctx0, layer.o_w, x);
out = ggml_add(ctx0, out, layer.o_b);
cb(out, "conformer.layers.{}.self_attn.linear_out", il);
cur = out;
}
residual = ggml_add(ctx0, residual, cur);
cur = build_norm(residual, layer.norm_conv_w, layer.norm_conv_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_conv", il);
// conv
{
auto * x = cur;
x = ggml_mul_mat(ctx0, layer.conv_pw1_w, x);
x = ggml_add(ctx0, x, layer.conv_pw1_b);
cb(x, "conformer.layers.{}.conv.pointwise_conv1", il);
// ggml_glu doesn't support sigmoid
// TODO @ngxson : support this ops in ggml
{
int64_t d = x->ne[0] / 2;
ggml_tensor * gate = ggml_sigmoid(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], d * x->nb[0]));
x = ggml_mul(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], 0), gate);
x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
}
// use ggml_ssm_conv for f32 precision
x = ggml_pad(ctx0, x, 4, 0, 0, 0);
x = ggml_roll(ctx0, x, 4, 0, 0, 0);
x = ggml_pad(ctx0, x, 4, 0, 0, 0);
x = ggml_ssm_conv(ctx0, x, layer.conv_dw_w);
x = ggml_add(ctx0, x, layer.conv_dw_b);
x = ggml_add(ctx0, ggml_mul(ctx0, x, layer.conv_norm_w), layer.conv_norm_b);
x = ggml_silu(ctx0, x);
// pointwise_conv2
x = ggml_mul_mat(ctx0, layer.conv_pw2_w, x);
x = ggml_add(ctx0, x, layer.conv_pw2_b);
cur = x;
}
residual = ggml_add(ctx0, residual, cur);
cur = build_norm(residual, layer.ff_norm_1_w, layer.ff_norm_1_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_feed_forward2", il);
cur = build_ffn(cur, layer.ff_up_1_w, layer.ff_up_1_b, nullptr, nullptr, layer.ff_down_1_w, layer.ff_down_1_b,
FFN_SILU, il); // TODO(tarek): read activation for ffn from hparams
cb(cur, "conformer.layers.{}.feed_forward2.linear2", il);
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
cb(residual, "conformer.layers.{}.conv.id", il);
cur = build_norm(residual, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_out", il);
}
// audio adapter
cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
cb(cur, "audio_adapter.model.{}", 0);
cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, nullptr, nullptr, model.mm_3_w, model.mm_3_b, FFN_GELU_ERF, -1);
cb(cur, "projected", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}

View File

@ -57,6 +57,11 @@ struct clip_graph_whisper_enc : clip_graph {
ggml_cgraph * build() override; ggml_cgraph * build() override;
}; };
struct clip_graph_conformer : clip_graph {
clip_graph_conformer(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
};
struct clip_graph_glm4v : clip_graph { struct clip_graph_glm4v : clip_graph {
clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override; ggml_cgraph * build() override;

View File

@ -535,3 +535,56 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
return true; return true;
} }
//
// mtmd_audio_preprocessor_conformer
//
void mtmd_audio_preprocessor_conformer::initialize() {
g_cache.fill_sin_cos_table(hparams.audio_n_fft);
g_cache.fill_hann_window(hparams.audio_window_len, true);
g_cache.fill_mel_filterbank_matrix(
hparams.n_mel_bins,
hparams.audio_n_fft,
hparams.audio_sample_rate);
}
bool mtmd_audio_preprocessor_conformer::preprocess(
const float * samples,
size_t n_samples,
std::vector<mtmd_audio_mel> & output) {
// empty audio
if (n_samples == 0) {
return false;
}
filter_params params;
params.n_mel = hparams.n_mel_bins;
params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
params.hann_window_size = hparams.audio_window_len;
params.hop_length = hparams.audio_hop_len;
params.sample_rate = hparams.audio_sample_rate;
params.center_padding = true;
params.preemph = 0.97f;
params.use_natural_log = true;
params.norm_per_feature = true;
// make sure the global cache is initialized
GGML_ASSERT(!g_cache.sin_vals.empty());
GGML_ASSERT(!g_cache.cos_vals.empty());
GGML_ASSERT(!g_cache.filters.data.empty());
mtmd_audio_mel out_full;
bool ok = log_mel_spectrogram(
samples,
n_samples,
4, // n_threads
params,
out_full);
if (!ok) {
return false;
}
output.push_back(std::move(out_full));
return true;
}

View File

@ -32,3 +32,9 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor {
void initialize() override; void initialize() override;
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override; bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
}; };
struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor {
mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
void initialize() override;
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
};

View File

@ -309,9 +309,24 @@ int main(int argc, char ** argv) {
if (g_is_interrupted) return 130; if (g_is_interrupted) return 130;
auto eval_system_prompt_if_present = [&] {
if (params.system_prompt.empty()) {
return 0;
}
common_chat_msg msg;
msg.role = "system";
msg.content = params.system_prompt;
return eval_message(ctx, msg);
};
LOG_WRN("WARN: This is an experimental CLI for testing multimodal capability.\n"); LOG_WRN("WARN: This is an experimental CLI for testing multimodal capability.\n");
LOG_WRN(" For normal use cases, please use the standard llama-cli\n"); LOG_WRN(" For normal use cases, please use the standard llama-cli\n");
if (eval_system_prompt_if_present()) {
return 1;
}
if (is_single_turn) { if (is_single_turn) {
g_is_generating = true; g_is_generating = true;
if (params.prompt.find(mtmd_default_marker()) == std::string::npos) { if (params.prompt.find(mtmd_default_marker()) == std::string::npos) {
@ -321,6 +336,7 @@ int main(int argc, char ** argv) {
params.prompt = mtmd_default_marker() + params.prompt; params.prompt = mtmd_default_marker() + params.prompt;
} }
} }
common_chat_msg msg; common_chat_msg msg;
msg.role = "user"; msg.role = "user";
msg.content = params.prompt; msg.content = params.prompt;
@ -369,6 +385,9 @@ int main(int argc, char ** argv) {
ctx.n_past = 0; ctx.n_past = 0;
ctx.chat_history.clear(); ctx.chat_history.clear();
llama_memory_clear(llama_get_memory(ctx.lctx), true); llama_memory_clear(llama_get_memory(ctx.lctx), true);
if (eval_system_prompt_if_present()) {
return 1;
}
LOG("Chat history cleared\n\n"); LOG("Chat history cleared\n\n");
continue; continue;
} }

View File

@ -332,6 +332,9 @@ struct mtmd_context {
case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_GLMA:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a); audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
break; break;
case PROJECTOR_TYPE_LFM2A:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_conformer>(ctx_a);
break;
default: default:
GGML_ABORT("unsupported audio projector type"); GGML_ABORT("unsupported audio projector type");
} }

View File

@ -84,6 +84,7 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M" add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
add_test_audio "ggml-org/LFM2-Audio-1.5B-GGUF:Q8_0"
# to test the big models, run: ./tests.sh big # to test the big models, run: ./tests.sh big
if [ "$RUN_BIG_TESTS" = true ]; then if [ "$RUN_BIG_TESTS" = true ]; then

View File

@ -75,9 +75,9 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggml-org/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) | | `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggml-org/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) |
| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) | | `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) |
| `--list-devices` | print list of available devices and exit | | `--list-devices` | print list of available devices and exit |
| `--override-tensor, -ot <tensor name pattern>=<buffer type>,...` | override tensor buffer type | | `-ot, --override-tensor <tensor name pattern>=<buffer type>,...` | override tensor buffer type |
| `--cpu-moe, -cmoe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) | | `-cmoe, --cpu-moe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) |
| `--n-cpu-moe, -ncmoe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) | | `-ncmoe, --n-cpu-moe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) |
| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)<br/>(env: LLAMA_ARG_N_GPU_LAYERS) | | `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) | | `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) | | `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
@ -120,7 +120,7 @@ For the ful list of features, please refer to [server's changelog](https://githu
| -------- | ----------- | | -------- | ----------- |
| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature) | | `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature) |
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
| `--sampling-seq, --sampler-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | | `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) |
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
| `--temp N` | temperature (default: 0.8) | | `--temp N` | temperature (default: 0.8) |
| `--top-k N` | top-k sampling (default: 40, 0 = disabled)<br/>(env: LLAMA_ARG_TOP_K) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled)<br/>(env: LLAMA_ARG_TOP_K) |
@ -156,8 +156,8 @@ For the ful list of features, please refer to [server's changelog](https://githu
| Argument | Explanation | | Argument | Explanation |
| -------- | ----------- | | -------- | ----------- |
| `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_CTX_CHECKPOINTS) | | `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_CTX_CHECKPOINTS) |
| `--cache-ram, -cram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) | | `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
| `--kv-unified, -kvu` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) | | `-kvu, --kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
| `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_CONTEXT_SHIFT) | | `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_CONTEXT_SHIFT) |
| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode<br/> | | `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode<br/> |
| `-sp, --special` | special tokens output enabled (default: false) | | `-sp, --special` | special tokens output enabled (default: false) |
@ -172,9 +172,9 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)<br/>(env: LLAMA_ARG_MMPROJ_OFFLOAD) | | `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)<br/>(env: LLAMA_ARG_MMPROJ_OFFLOAD) |
| `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | | `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MIN_TOKENS) |
| `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | | `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MAX_TOKENS) |
| `--override-tensor-draft, -otd <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model | | `-otd, --override-tensor-draft <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model |
| `--cpu-moe-draft, -cmoed` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>(env: LLAMA_ARG_CPU_MOE_DRAFT) | | `-cmoed, --cpu-moe-draft` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>(env: LLAMA_ARG_CPU_MOE_DRAFT) |
| `--n-cpu-moe-draft, -ncmoed N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model<br/>(env: LLAMA_ARG_N_CPU_MOE_DRAFT) | | `-ncmoed, --n-cpu-moe-draft N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model<br/>(env: LLAMA_ARG_N_CPU_MOE_DRAFT) |
| `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) | | `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
| `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) | | `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) | | `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
@ -184,7 +184,7 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--webui-config-file PATH` | JSON file that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG_FILE) | | `--webui-config-file PATH` | JSON file that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG_FILE) |
| `--webui, --no-webui` | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_WEBUI) | | `--webui, --no-webui` | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_WEBUI) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) | | `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
| `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) | | `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
| `--api-key-file FNAME` | path to file containing API keys (default: none) | | `--api-key-file FNAME` | path to file containing API keys (default: none) |
| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) | | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
@ -212,7 +212,7 @@ For the ful list of features, please refer to [server's changelog](https://githu
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
| `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) |
| `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | | `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) |
| `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16)<br/>(env: LLAMA_ARG_DRAFT_MAX) | | `--draft, --draft-n, --draft-max N` | number of tokens to draft for speculative decoding (default: 16)<br/>(env: LLAMA_ARG_DRAFT_MAX) |
| `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_DRAFT_MIN) | | `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_DRAFT_MIN) |
| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.8)<br/>(env: LLAMA_ARG_DRAFT_P_MIN) | | `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.8)<br/>(env: LLAMA_ARG_DRAFT_P_MIN) |
| `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_CTX_SIZE_DRAFT) | | `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_CTX_SIZE_DRAFT) |
@ -1443,6 +1443,12 @@ Example:
```ini ```ini
version = 1 version = 1
; (Optional) This section provides global settings shared across all presets.
; If the same key is defined in a specific preset, it will override the value in this global section.
[*]
c = 8192
n-gpu-layer = 8
; If the key corresponds to an existing model on the server, ; If the key corresponds to an existing model on the server,
; this will be used as the default config for that model ; this will be used as the default config for that model
[ggml-org/MY-MODEL-GGUF:Q8_0] [ggml-org/MY-MODEL-GGUF:Q8_0]
@ -1462,12 +1468,17 @@ model-draft = ./my-models/draft.gguf
model-draft = /Users/abc/my-models/draft.gguf model-draft = /Users/abc/my-models/draft.gguf
; If the key does NOT correspond to an existing model, ; If the key does NOT correspond to an existing model,
; you need to specify at least the model path ; you need to specify at least the model path or HF repo
[custom_model] [custom_model]
model = /Users/abc/my-awesome-model-Q4_K_M.gguf model = /Users/abc/my-awesome-model-Q4_K_M.gguf
``` ```
Note: some arguments are controlled by router (e.g., host, port, API key, HF repo, model alias). They will be removed or overwritten upload loading. Note: some arguments are controlled by router (e.g., host, port, API key, HF repo, model alias). They will be removed or overwritten upon loading.
The precedence rule for preset options is as follows:
1. **Command-line arguments** passed to `llama-server` (highest priority)
2. **Model-specific options** defined in the preset file (e.g. `[ggml-org/MY-MODEL...]`)
3. **Global options** defined in the preset file (`[*]`)
### Routing requests ### Routing requests

Binary file not shown.

View File

@ -1974,19 +1974,33 @@ struct server_context_impl {
if (!slot.can_split()) { if (!slot.can_split()) {
if (slot.task->n_tokens() > n_ubatch) { if (slot.task->n_tokens() > n_ubatch) {
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); send_error(slot,
string_format(
"input (%d tokens) is too large to process. increase the physical batch "
"size (current batch size: %d)",
slot.task->n_tokens(), n_ubatch),
ERROR_TYPE_SERVER);
slot.release(); slot.release();
continue; continue;
} }
if (slot.task->n_tokens() > slot.n_ctx) { if (slot.task->n_tokens() > slot.n_ctx) {
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); send_error(
slot,
string_format(
"input (%d tokens) is larger than the max context size (%d tokens). skipping",
slot.task->n_tokens(), slot.n_ctx),
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release(); slot.release();
continue; continue;
} }
} else { } else {
if (slot.task->n_tokens() >= slot.n_ctx) { if (slot.task->n_tokens() >= slot.n_ctx) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); send_error(slot,
string_format("request (%d tokens) exceeds the available context size (%d "
"tokens), try increasing it",
slot.task->n_tokens(), slot.n_ctx),
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release(); slot.release();
continue; continue;
} }

View File

@ -82,154 +82,30 @@ static std::filesystem::path get_server_exec_path() {
#endif #endif
} }
struct local_model { static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
std::string name; preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
std::string path; preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
std::string path_mmproj; preset.unset_option("LLAMA_API_KEY");
}; preset.unset_option("LLAMA_ARG_MODELS_DIR");
preset.unset_option("LLAMA_ARG_MODELS_MAX");
static std::vector<local_model> list_local_models(const std::string & dir) { preset.unset_option("LLAMA_ARG_MODELS_PRESET");
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) { preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str())); if (unset_model_args) {
} preset.unset_option("LLAMA_ARG_MODEL");
preset.unset_option("LLAMA_ARG_MMPROJ");
std::vector<local_model> models; preset.unset_option("LLAMA_ARG_HF_REPO");
auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
auto files = fs_list(subdir_path, false);
common_file_info model_file;
common_file_info first_shard_file;
common_file_info mmproj_file;
for (const auto & file : files) {
if (string_ends_with(file.name, ".gguf")) {
if (file.name.find("mmproj") != std::string::npos) {
mmproj_file = file;
} else if (file.name.find("-00001-of-") != std::string::npos) {
first_shard_file = file;
} else {
model_file = file;
}
}
}
// single file model
local_model model{
/* name */ name,
/* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
/* path_mmproj */ mmproj_file.path // can be empty
};
if (!model.path.empty()) {
models.push_back(model);
}
};
auto files = fs_list(dir, true);
for (const auto & file : files) {
if (file.is_dir) {
scan_subdir(file.path, file.name);
} else if (string_ends_with(file.name, ".gguf")) {
// single file model
std::string name = file.name;
string_replace_all(name, ".gguf", "");
local_model model{
/* name */ name,
/* path */ file.path,
/* path_mmproj */ ""
};
models.push_back(model);
}
}
return models;
}
//
// server_presets
//
server_presets::server_presets(int argc, char ** argv, common_params & base_params, const std::string & presets_path)
: ctx_params(common_params_parser_init(base_params, LLAMA_EXAMPLE_SERVER)) {
if (!presets_path.empty()) {
presets = common_presets_load(presets_path, ctx_params);
SRV_INF("Loaded %zu presets from %s\n", presets.size(), presets_path.c_str());
}
// populate reserved args (will be appended by the router)
for (auto & opt : ctx_params.options) {
if (opt.env == nullptr) {
continue;
}
std::string env = opt.env;
if (env == "LLAMA_ARG_PORT" ||
env == "LLAMA_ARG_HOST" ||
env == "LLAMA_ARG_ALIAS" ||
env == "LLAMA_ARG_API_KEY" ||
env == "LLAMA_ARG_MODELS_DIR" ||
env == "LLAMA_ARG_MODELS_MAX" ||
env == "LLAMA_ARG_MODELS_PRESET" ||
env == "LLAMA_ARG_MODEL" ||
env == "LLAMA_ARG_MMPROJ" ||
env == "LLAMA_ARG_HF_REPO" ||
env == "LLAMA_ARG_NO_MODELS_AUTOLOAD" ||
env == "LLAMA_ARG_SSL_KEY_FILE" ||
env == "LLAMA_ARG_SSL_CERT_FILE") {
control_args[env] = opt;
}
}
// read base args from router's argv
common_params_to_map(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
// remove any router-controlled args from base_args
for (const auto & cargs : control_args) {
auto it = base_args.find(cargs.second);
if (it != base_args.end()) {
base_args.erase(it);
}
} }
} }
common_preset server_presets::get_preset(const std::string & name) { void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
auto it = presets.find(name); // update params
if (it != presets.end()) { unset_reserved_args(preset, false);
return it->second; preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR);
} preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port));
return common_preset(); preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
} // TODO: maybe validate preset before rendering ?
// render args
void server_presets::render_args(server_model_meta & meta) { args = preset.to_args(bin_path);
common_preset preset = meta.preset; // copy
// merging 3 kinds of args:
// 1. model-specific args (from preset)
// force removing control args if any
for (auto & cargs : control_args) {
if (preset.options.find(cargs.second) != preset.options.end()) {
SRV_WRN("Preset '%s' contains reserved arg '%s', removing it\n", preset.name.c_str(), cargs.second.args[0]);
preset.options.erase(cargs.second);
}
}
// 2. base args (from router)
// inherit from base args
for (const auto & [arg, value] : base_args) {
preset.options[arg] = value;
}
// 3. control args (from router)
// set control values
preset.options[control_args["LLAMA_ARG_HOST"]] = CHILD_ADDR;
preset.options[control_args["LLAMA_ARG_PORT"]] = std::to_string(meta.port);
preset.options[control_args["LLAMA_ARG_ALIAS"]] = meta.name;
if (meta.in_cache) {
preset.options[control_args["LLAMA_ARG_HF_REPO"]] = meta.name;
} else {
preset.options[control_args["LLAMA_ARG_MODEL"]] = meta.path;
if (!meta.path_mmproj.empty()) {
preset.options[control_args["LLAMA_ARG_MMPROJ"]] = meta.path_mmproj;
}
}
// disable SSL for child processes (HTTPS already handled by router)
preset.options[control_args["LLAMA_ARG_SSL_KEY_FILE"]] = "";
preset.options[control_args["LLAMA_ARG_SSL_CERT_FILE"]] = "";
meta.args = preset.to_args();
// add back the binary path at the front
meta.args.insert(meta.args.begin(), get_server_exec_path().string());
} }
// //
@ -240,20 +116,22 @@ server_models::server_models(
const common_params & params, const common_params & params,
int argc, int argc,
char ** argv, char ** argv,
char ** envp) : base_params(params), presets(argc, argv, base_params, params.models_preset) { char ** envp)
for (int i = 0; i < argc; i++) { : ctx_preset(LLAMA_EXAMPLE_SERVER),
base_args.push_back(std::string(argv[i])); base_params(params),
} base_preset(ctx_preset.load_from_args(argc, argv)) {
for (char ** env = envp; *env != nullptr; env++) { for (char ** env = envp; *env != nullptr; env++) {
base_env.push_back(std::string(*env)); base_env.push_back(std::string(*env));
} }
GGML_ASSERT(!base_args.empty()); // clean up base preset
unset_reserved_args(base_preset, true);
// set binary path // set binary path
try { try {
base_args[0] = get_server_exec_path().string(); bin_path = get_server_exec_path().string();
} catch (const std::exception & e) { } catch (const std::exception & e) {
bin_path = argv[0];
LOG_WRN("failed to get server executable path: %s\n", e.what()); LOG_WRN("failed to get server executable path: %s\n", e.what());
LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str()); LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
} }
load_models(); load_models();
} }
@ -262,7 +140,7 @@ void server_models::add_model(server_model_meta && meta) {
if (mapping.find(meta.name) != mapping.end()) { if (mapping.find(meta.name) != mapping.end()) {
throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str())); throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
} }
presets.render_args(meta); // populate meta.args meta.update_args(ctx_preset, bin_path); // render args
std::string name = meta.name; std::string name = meta.name;
mapping[name] = instance_t{ mapping[name] = instance_t{
/* subproc */ std::make_shared<subprocess_s>(), /* subproc */ std::make_shared<subprocess_s>(),
@ -271,86 +149,62 @@ void server_models::add_model(server_model_meta && meta) {
}; };
} }
static std::vector<local_model> list_custom_path_models(server_presets & presets) {
// detect any custom-path models in presets
std::vector<local_model> custom_models;
for (auto & [model_name, preset] : presets.presets) {
local_model model;
model.name = model_name;
std::vector<common_arg> to_erase;
for (auto & [arg, value] : preset.options) {
std::string env(arg.env ? arg.env : "");
if (env == "LLAMA_ARG_MODEL") {
model.path = value;
to_erase.push_back(arg);
}
if (env == "LLAMA_ARG_MMPROJ") {
model.path_mmproj = value;
to_erase.push_back(arg);
}
}
for (auto & arg : to_erase) {
preset.options.erase(arg);
}
if (!model.name.empty() && !model.path.empty()) {
custom_models.push_back(model);
}
}
return custom_models;
}
// TODO: allow refreshing cached model list // TODO: allow refreshing cached model list
void server_models::load_models() { void server_models::load_models() {
// loading models from 3 sources: // loading models from 3 sources:
// 1. cached models // 1. cached models
auto cached_models = common_list_cached_models(); common_presets cached_models = ctx_preset.load_from_cache();
for (const auto & model : cached_models) { SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
server_model_meta meta{ // 2. local models from --models-dir
/* preset */ presets.get_preset(model.to_string()), common_presets local_models;
/* name */ model.to_string(),
/* path */ model.manifest_path,
/* path_mmproj */ "", // auto-detected when loading
/* in_cache */ true,
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* exit_code */ 0
};
add_model(std::move(meta));
}
// 2. local models specificed via --models-dir
if (!base_params.models_dir.empty()) { if (!base_params.models_dir.empty()) {
auto local_models = list_local_models(base_params.models_dir); local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
for (const auto & model : local_models) { SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
if (mapping.find(model.name) != mapping.end()) { }
// already exists in cached models, skip // 3. custom-path models from presets
continue; common_preset global = {};
} common_presets custom_presets = {};
server_model_meta meta{ if (!base_params.models_preset.empty()) {
/* preset */ presets.get_preset(model.name), custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
/* name */ model.name, SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
/* path */ model.path, }
/* path_mmproj */ model.path_mmproj,
/* in_cache */ false, // cascade, apply global preset first
/* port */ 0, cached_models = ctx_preset.cascade(global, cached_models);
/* status */ SERVER_MODEL_STATUS_UNLOADED, local_models = ctx_preset.cascade(global, local_models);
/* last_used */ 0, custom_presets = ctx_preset.cascade(global, custom_presets);
/* args */ std::vector<std::string>(),
/* exit_code */ 0 // note: if a model exists in both cached and local, local takes precedence
}; common_presets final_presets;
add_model(std::move(meta)); for (const auto & [name, preset] : cached_models) {
final_presets[name] = preset;
}
for (const auto & [name, preset] : local_models) {
final_presets[name] = preset;
}
// process custom presets from INI
for (const auto & [name, custom] : custom_presets) {
if (final_presets.find(name) != final_presets.end()) {
// apply custom config if exists
common_preset & target = final_presets[name];
target.merge(custom);
} else {
// otherwise add directly
final_presets[name] = custom;
} }
} }
// 3. custom-path models specified in presets
auto custom_models = list_custom_path_models(presets); // server base preset from CLI args take highest precedence
for (const auto & model : custom_models) { for (auto & [name, preset] : final_presets) {
preset.merge(base_preset);
}
// convert presets to server_model_meta and add to mapping
for (const auto & preset : final_presets) {
server_model_meta meta{ server_model_meta meta{
/* preset */ presets.get_preset(model.name), /* preset */ preset.second,
/* name */ model.name, /* name */ preset.first,
/* path */ model.path,
/* path_mmproj */ model.path_mmproj,
/* in_cache */ false,
/* port */ 0, /* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED, /* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0, /* last_used */ 0,
@ -359,10 +213,18 @@ void server_models::load_models() {
}; };
add_model(std::move(meta)); add_model(std::move(meta));
} }
// log available models // log available models
SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size()); {
for (const auto & [name, inst] : mapping) { std::unordered_set<std::string> custom_names;
SRV_INF(" %c %s\n", inst.meta.preset.name.empty() ? ' ' : '*', name.c_str()); for (const auto & [name, preset] : custom_presets) {
custom_names.insert(name);
}
SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
for (const auto & [name, inst] : mapping) {
bool has_custom = custom_names.find(name) != custom_names.end();
SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str());
}
} }
} }
@ -526,7 +388,7 @@ void server_models::load(const std::string & name) {
{ {
SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
presets.render_args(inst.meta); // update meta.args inst.meta.update_args(ctx_preset, bin_path); // render args
std::vector<std::string> child_args = inst.meta.args; // copy std::vector<std::string> child_args = inst.meta.args; // copy
std::vector<std::string> child_env = base_env; // copy std::vector<std::string> child_env = base_env; // copy
@ -877,7 +739,12 @@ void server_models_routes::init_routes() {
{"args", meta.args}, {"args", meta.args},
}; };
if (!meta.preset.name.empty()) { if (!meta.preset.name.empty()) {
status["preset"] = meta.preset.to_ini(); common_preset preset_copy = meta.preset;
unset_reserved_args(preset_copy, false);
preset_copy.unset_option("LLAMA_ARG_HOST");
preset_copy.unset_option("LLAMA_ARG_PORT");
preset_copy.unset_option("LLAMA_ARG_ALIAS");
status["preset"] = preset_copy.to_ini();
} }
if (meta.is_failed()) { if (meta.is_failed()) {
status["exit_code"] = meta.exit_code; status["exit_code"] = meta.exit_code;
@ -888,8 +755,6 @@ void server_models_routes::init_routes() {
{"object", "model"}, // for OAI-compat {"object", "model"}, // for OAI-compat
{"owned_by", "llamacpp"}, // for OAI-compat {"owned_by", "llamacpp"}, // for OAI-compat
{"created", t}, // for OAI-compat {"created", t}, // for OAI-compat
{"in_cache", meta.in_cache},
{"path", meta.path},
{"status", status}, {"status", status},
// TODO: add other fields, may require reading GGUF metadata // TODO: add other fields, may require reading GGUF metadata
}); });

View File

@ -51,9 +51,6 @@ static std::string server_model_status_to_string(server_model_status status) {
struct server_model_meta { struct server_model_meta {
common_preset preset; common_preset preset;
std::string name; std::string name;
std::string path;
std::string path_mmproj; // only available if in_cache=false
bool in_cache = false; // if true, use -hf; use -m otherwise
int port = 0; int port = 0;
server_model_status status = SERVER_MODEL_STATUS_UNLOADED; server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
int64_t last_used = 0; // for LRU unloading int64_t last_used = 0; // for LRU unloading
@ -67,19 +64,8 @@ struct server_model_meta {
bool is_failed() const { bool is_failed() const {
return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0; return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
} }
};
// the server_presets struct holds the presets read from presets.ini void update_args(common_preset_context & ctx_presets, std::string bin_path);
// as well as base args from the router server
struct server_presets {
common_presets presets;
common_params_context ctx_params;
std::map<common_arg, std::string> base_args;
std::map<std::string, common_arg> control_args; // args reserved for server control
server_presets(int argc, char ** argv, common_params & base_params, const std::string & models_dir);
common_preset get_preset(const std::string & name);
void render_args(server_model_meta & meta);
}; };
struct subprocess_s; struct subprocess_s;
@ -97,11 +83,12 @@ private:
std::condition_variable cv; std::condition_variable cv;
std::map<std::string, instance_t> mapping; std::map<std::string, instance_t> mapping;
common_params base_params; common_preset_context ctx_preset;
std::vector<std::string> base_args;
std::vector<std::string> base_env;
server_presets presets; common_params base_params;
std::string bin_path;
std::vector<std::string> base_env;
common_preset base_preset; // base preset from llama-server CLI args
void update_meta(const std::string & name, const server_model_meta & meta); void update_meta(const std::string & name, const server_model_meta & meta);

View File

@ -11,6 +11,8 @@ flowchart TB
C_Screen["ChatScreen"] C_Screen["ChatScreen"]
C_Form["ChatForm"] C_Form["ChatForm"]
C_Messages["ChatMessages"] C_Messages["ChatMessages"]
C_Message["ChatMessage"]
C_MessageEditForm["ChatMessageEditForm"]
C_ModelsSelector["ModelsSelector"] C_ModelsSelector["ModelsSelector"]
C_Settings["ChatSettings"] C_Settings["ChatSettings"]
end end
@ -54,7 +56,9 @@ flowchart TB
%% Component hierarchy %% Component hierarchy
C_Screen --> C_Form & C_Messages & C_Settings C_Screen --> C_Form & C_Messages & C_Settings
C_Form & C_Messages --> C_ModelsSelector C_Messages --> C_Message
C_Message --> C_MessageEditForm
C_Form & C_MessageEditForm --> C_ModelsSelector
%% Components → Hooks → Stores %% Components → Hooks → Stores
C_Form & C_Messages --> H1 & H2 C_Form & C_Messages --> H1 & H2
@ -93,7 +97,7 @@ flowchart TB
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
class R1,R2,RL routeStyle class R1,R2,RL routeStyle
class C_Sidebar,C_Screen,C_Form,C_Messages,C_ModelsSelector,C_Settings componentStyle class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageEditForm,C_ModelsSelector,C_Settings componentStyle
class H1,H2 hookStyle class H1,H2 hookStyle
class S1,S2,S3,S4,S5 storeStyle class S1,S2,S3,S4,S5 storeStyle
class SV1,SV2,SV3,SV4,SV5 serviceStyle class SV1,SV2,SV3,SV4,SV5 serviceStyle

View File

@ -16,6 +16,8 @@ end
C_Form["ChatForm"] C_Form["ChatForm"]
C_Messages["ChatMessages"] C_Messages["ChatMessages"]
C_Message["ChatMessage"] C_Message["ChatMessage"]
C_MessageUser["ChatMessageUser"]
C_MessageEditForm["ChatMessageEditForm"]
C_Attach["ChatAttachments"] C_Attach["ChatAttachments"]
C_ModelsSelector["ModelsSelector"] C_ModelsSelector["ModelsSelector"]
C_Settings["ChatSettings"] C_Settings["ChatSettings"]
@ -38,7 +40,7 @@ end
S1Error["<b>Error Handling:</b><br/>showErrorDialog()<br/>dismissErrorDialog()<br/>isAbortError()"] S1Error["<b>Error Handling:</b><br/>showErrorDialog()<br/>dismissErrorDialog()<br/>isAbortError()"]
S1Msg["<b>Message Operations:</b><br/>addMessage()<br/>sendMessage()<br/>updateMessage()<br/>deleteMessage()<br/>getDeletionInfo()"] S1Msg["<b>Message Operations:</b><br/>addMessage()<br/>sendMessage()<br/>updateMessage()<br/>deleteMessage()<br/>getDeletionInfo()"]
S1Regen["<b>Regeneration:</b><br/>regenerateMessage()<br/>regenerateMessageWithBranching()<br/>continueAssistantMessage()"] S1Regen["<b>Regeneration:</b><br/>regenerateMessage()<br/>regenerateMessageWithBranching()<br/>continueAssistantMessage()"]
S1Edit["<b>Editing:</b><br/>editAssistantMessage()<br/>editUserMessagePreserveResponses()<br/>editMessageWithBranching()"] S1Edit["<b>Editing:</b><br/>editAssistantMessage()<br/>editUserMessagePreserveResponses()<br/>editMessageWithBranching()<br/>clearEditMode()<br/>isEditModeActive()<br/>getAddFilesHandler()<br/>setEditModeActive()"]
S1Utils["<b>Utilities:</b><br/>getApiOptions()<br/>parseTimingData()<br/>getOrCreateAbortController()<br/>getConversationModel()"] S1Utils["<b>Utilities:</b><br/>getApiOptions()<br/>parseTimingData()<br/>getOrCreateAbortController()<br/>getConversationModel()"]
end end
subgraph S2["conversationsStore"] subgraph S2["conversationsStore"]
@ -88,6 +90,10 @@ end
RE7["getChatStreaming()"] RE7["getChatStreaming()"]
RE8["getAllLoadingChats()"] RE8["getAllLoadingChats()"]
RE9["getAllStreamingChats()"] RE9["getAllStreamingChats()"]
RE9a["isEditModeActive()"]
RE9b["getAddFilesHandler()"]
RE9c["setEditModeActive()"]
RE9d["clearEditMode()"]
end end
subgraph ConvExports["conversationsStore"] subgraph ConvExports["conversationsStore"]
RE10["conversations()"] RE10["conversations()"]
@ -182,7 +188,10 @@ end
%% Component hierarchy %% Component hierarchy
C_Screen --> C_Form & C_Messages & C_Settings C_Screen --> C_Form & C_Messages & C_Settings
C_Messages --> C_Message C_Messages --> C_Message
C_Message --> C_ModelsSelector C_Message --> C_MessageUser
C_MessageUser --> C_MessageEditForm
C_MessageEditForm --> C_ModelsSelector
C_MessageEditForm --> C_Attach
C_Form --> C_ModelsSelector C_Form --> C_ModelsSelector
C_Form --> C_Attach C_Form --> C_Attach
C_Message --> C_Attach C_Message --> C_Attach
@ -190,6 +199,7 @@ end
%% Components use Hooks %% Components use Hooks
C_Form --> H1 C_Form --> H1
C_Message --> H1 & H2 C_Message --> H1 & H2
C_MessageEditForm --> H1
C_Screen --> H2 C_Screen --> H2
%% Hooks use Stores %% Hooks use Stores
@ -244,7 +254,7 @@ end
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
class R1,R2,RL routeStyle class R1,R2,RL routeStyle
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message componentStyle class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageUser,C_MessageEditForm componentStyle
class C_ModelsSelector,C_Settings componentStyle class C_ModelsSelector,C_Settings componentStyle
class C_Attach componentStyle class C_Attach componentStyle
class H1,H2,H3 methodStyle class H1,H2,H3 methodStyle

View File

@ -25,7 +25,7 @@
"@chromatic-com/storybook": "^4.1.2", "@chromatic-com/storybook": "^4.1.2",
"@eslint/compat": "^1.2.5", "@eslint/compat": "^1.2.5",
"@eslint/js": "^9.18.0", "@eslint/js": "^9.18.0",
"@internationalized/date": "^3.8.2", "@internationalized/date": "^3.10.1",
"@lucide/svelte": "^0.515.0", "@lucide/svelte": "^0.515.0",
"@playwright/test": "^1.49.1", "@playwright/test": "^1.49.1",
"@storybook/addon-a11y": "^10.0.7", "@storybook/addon-a11y": "^10.0.7",
@ -862,9 +862,9 @@
} }
}, },
"node_modules/@internationalized/date": { "node_modules/@internationalized/date": {
"version": "3.8.2", "version": "3.10.1",
"resolved": "https://registry.npmjs.org/@internationalized/date/-/date-3.8.2.tgz", "resolved": "https://registry.npmjs.org/@internationalized/date/-/date-3.10.1.tgz",
"integrity": "sha512-/wENk7CbvLbkUvX1tu0mwq49CVkkWpkXubGel6birjRPyo6uQ4nQpnq5xZu823zRCwwn82zgHrvgF1vZyvmVgA==", "integrity": "sha512-oJrXtQiAXLvT9clCf1K4kxp3eKsQhIaZqxEyowkBcsvZDdZkbWrVmnGknxs5flTD0VGsxrxKgBCZty1EzoiMzA==",
"dev": true, "dev": true,
"license": "Apache-2.0", "license": "Apache-2.0",
"dependencies": { "dependencies": {

View File

@ -26,7 +26,7 @@
"@chromatic-com/storybook": "^4.1.2", "@chromatic-com/storybook": "^4.1.2",
"@eslint/compat": "^1.2.5", "@eslint/compat": "^1.2.5",
"@eslint/js": "^9.18.0", "@eslint/js": "^9.18.0",
"@internationalized/date": "^3.8.2", "@internationalized/date": "^3.10.1",
"@lucide/svelte": "^0.515.0", "@lucide/svelte": "^0.515.0",
"@playwright/test": "^1.49.1", "@playwright/test": "^1.49.1",
"@storybook/addon-a11y": "^10.0.7", "@storybook/addon-a11y": "^10.0.7",

View File

@ -8,6 +8,7 @@
ChatFormTextarea ChatFormTextarea
} from '$lib/components/app'; } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes'; import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
import { config } from '$lib/stores/settings.svelte'; import { config } from '$lib/stores/settings.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte'; import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte'; import { isRouterMode } from '$lib/stores/server.svelte';
@ -66,7 +67,7 @@
let message = $state(''); let message = $state('');
let pasteLongTextToFileLength = $derived.by(() => { let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen); const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? 2500 : n; return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
}); });
let previousIsLoading = $state(isLoading); let previousIsLoading = $state(isLoading);
let recordingSupported = $state(false); let recordingSupported = $state(false);

View File

@ -12,13 +12,21 @@
onCopy?: (message: DatabaseMessage) => void; onCopy?: (message: DatabaseMessage) => void;
onContinueAssistantMessage?: (message: DatabaseMessage) => void; onContinueAssistantMessage?: (message: DatabaseMessage) => void;
onDelete?: (message: DatabaseMessage) => void; onDelete?: (message: DatabaseMessage) => void;
onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void; onEditWithBranching?: (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => void;
onEditWithReplacement?: ( onEditWithReplacement?: (
message: DatabaseMessage, message: DatabaseMessage,
newContent: string, newContent: string,
shouldBranch: boolean shouldBranch: boolean
) => void; ) => void;
onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void; onEditUserMessagePreserveResponses?: (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => void;
onNavigateToSibling?: (siblingId: string) => void; onNavigateToSibling?: (siblingId: string) => void;
onRegenerateWithBranching?: (message: DatabaseMessage, modelOverride?: string) => void; onRegenerateWithBranching?: (message: DatabaseMessage, modelOverride?: string) => void;
siblingInfo?: ChatMessageSiblingInfo | null; siblingInfo?: ChatMessageSiblingInfo | null;
@ -45,6 +53,8 @@
messageTypes: string[]; messageTypes: string[];
} | null>(null); } | null>(null);
let editedContent = $state(message.content); let editedContent = $state(message.content);
let editedExtras = $state<DatabaseMessageExtra[]>(message.extra ? [...message.extra] : []);
let editedUploadedFiles = $state<ChatUploadedFile[]>([]);
let isEditing = $state(false); let isEditing = $state(false);
let showDeleteDialog = $state(false); let showDeleteDialog = $state(false);
let shouldBranchAfterEdit = $state(false); let shouldBranchAfterEdit = $state(false);
@ -85,6 +95,16 @@
function handleCancelEdit() { function handleCancelEdit() {
isEditing = false; isEditing = false;
editedContent = message.content; editedContent = message.content;
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
}
function handleEditedExtrasChange(extras: DatabaseMessageExtra[]) {
editedExtras = extras;
}
function handleEditedUploadedFilesChange(files: ChatUploadedFile[]) {
editedUploadedFiles = files;
} }
async function handleCopy() { async function handleCopy() {
@ -107,6 +127,8 @@
function handleEdit() { function handleEdit() {
isEditing = true; isEditing = true;
editedContent = message.content; editedContent = message.content;
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
setTimeout(() => { setTimeout(() => {
if (textareaElement) { if (textareaElement) {
@ -143,9 +165,10 @@
onContinueAssistantMessage?.(message); onContinueAssistantMessage?.(message);
} }
function handleSaveEdit() { async function handleSaveEdit() {
if (message.role === 'user' || message.role === 'system') { if (message.role === 'user' || message.role === 'system') {
onEditWithBranching?.(message, editedContent.trim()); const finalExtras = await getMergedExtras();
onEditWithBranching?.(message, editedContent.trim(), finalExtras);
} else { } else {
// For assistant messages, preserve exact content including trailing whitespace // For assistant messages, preserve exact content including trailing whitespace
// This is important for the Continue feature to work properly // This is important for the Continue feature to work properly
@ -154,15 +177,30 @@
isEditing = false; isEditing = false;
shouldBranchAfterEdit = false; shouldBranchAfterEdit = false;
editedUploadedFiles = [];
} }
function handleSaveEditOnly() { async function handleSaveEditOnly() {
if (message.role === 'user') { if (message.role === 'user') {
// For user messages, trim to avoid accidental whitespace // For user messages, trim to avoid accidental whitespace
onEditUserMessagePreserveResponses?.(message, editedContent.trim()); const finalExtras = await getMergedExtras();
onEditUserMessagePreserveResponses?.(message, editedContent.trim(), finalExtras);
} }
isEditing = false; isEditing = false;
editedUploadedFiles = [];
}
async function getMergedExtras(): Promise<DatabaseMessageExtra[]> {
if (editedUploadedFiles.length === 0) {
return editedExtras;
}
const { parseFilesToMessageExtras } = await import('$lib/utils/browser-only');
const result = await parseFilesToMessageExtras(editedUploadedFiles);
const newExtras = result?.extras || [];
return [...editedExtras, ...newExtras];
} }
function handleShowDeleteDialogChange(show: boolean) { function handleShowDeleteDialogChange(show: boolean) {
@ -197,6 +235,8 @@
class={className} class={className}
{deletionInfo} {deletionInfo}
{editedContent} {editedContent}
{editedExtras}
{editedUploadedFiles}
{isEditing} {isEditing}
{message} {message}
onCancelEdit={handleCancelEdit} onCancelEdit={handleCancelEdit}
@ -206,6 +246,8 @@
onEdit={handleEdit} onEdit={handleEdit}
onEditKeydown={handleEditKeydown} onEditKeydown={handleEditKeydown}
onEditedContentChange={handleEditedContentChange} onEditedContentChange={handleEditedContentChange}
onEditedExtrasChange={handleEditedExtrasChange}
onEditedUploadedFilesChange={handleEditedUploadedFilesChange}
{onNavigateToSibling} {onNavigateToSibling}
onSaveEdit={handleSaveEdit} onSaveEdit={handleSaveEdit}
onSaveEditOnly={handleSaveEditOnly} onSaveEditOnly={handleSaveEditOnly}

View File

@ -0,0 +1,391 @@
<script lang="ts">
import { X, ArrowUp, Paperclip, AlertTriangle } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Switch } from '$lib/components/ui/switch';
import { ChatAttachmentsList, DialogConfirmation, ModelsSelector } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
import { AttachmentType, FileTypeCategory, MimeTypeText } from '$lib/enums';
import { config } from '$lib/stores/settings.svelte';
import { useModelChangeValidation } from '$lib/hooks/use-model-change-validation.svelte';
import { setEditModeActive, clearEditMode } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { modelsStore } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import {
autoResizeTextarea,
getFileTypeCategory,
getFileTypeCategoryByExtension,
parseClipboardContent
} from '$lib/utils';
interface Props {
messageId: string;
editedContent: string;
editedExtras?: DatabaseMessageExtra[];
editedUploadedFiles?: ChatUploadedFile[];
originalContent: string;
originalExtras?: DatabaseMessageExtra[];
showSaveOnlyOption?: boolean;
onCancelEdit: () => void;
onSaveEdit: () => void;
onSaveEditOnly?: () => void;
onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void;
onEditedExtrasChange?: (extras: DatabaseMessageExtra[]) => void;
onEditedUploadedFilesChange?: (files: ChatUploadedFile[]) => void;
textareaElement?: HTMLTextAreaElement;
}
let {
messageId,
editedContent,
editedExtras = [],
editedUploadedFiles = [],
originalContent,
originalExtras = [],
showSaveOnlyOption = false,
onCancelEdit,
onSaveEdit,
onSaveEditOnly,
onEditKeydown,
onEditedContentChange,
onEditedExtrasChange,
onEditedUploadedFilesChange,
textareaElement = $bindable()
}: Props = $props();
let fileInputElement: HTMLInputElement | undefined = $state();
let saveWithoutRegenerate = $state(false);
let showDiscardDialog = $state(false);
let isRouter = $derived(isRouterMode());
let currentConfig = $derived(config());
let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
});
let hasUnsavedChanges = $derived.by(() => {
if (editedContent !== originalContent) return true;
if (editedUploadedFiles.length > 0) return true;
const extrasChanged =
editedExtras.length !== originalExtras.length ||
editedExtras.some((extra, i) => extra !== originalExtras[i]);
if (extrasChanged) return true;
return false;
});
let hasAttachments = $derived(
(editedExtras && editedExtras.length > 0) ||
(editedUploadedFiles && editedUploadedFiles.length > 0)
);
let canSubmit = $derived(editedContent.trim().length > 0 || hasAttachments);
function getEditedAttachmentsModalities(): ModelModalities {
const modalities: ModelModalities = { vision: false, audio: false };
for (const extra of editedExtras) {
if (extra.type === AttachmentType.IMAGE) {
modalities.vision = true;
}
if (
extra.type === AttachmentType.PDF &&
'processedAsImages' in extra &&
extra.processedAsImages
) {
modalities.vision = true;
}
if (extra.type === AttachmentType.AUDIO) {
modalities.audio = true;
}
}
for (const file of editedUploadedFiles) {
const category = getFileTypeCategory(file.type) || getFileTypeCategoryByExtension(file.name);
if (category === FileTypeCategory.IMAGE) {
modalities.vision = true;
}
if (category === FileTypeCategory.AUDIO) {
modalities.audio = true;
}
}
return modalities;
}
function getRequiredModalities(): ModelModalities {
const beforeModalities = conversationsStore.getModalitiesUpToMessage(messageId);
const editedModalities = getEditedAttachmentsModalities();
return {
vision: beforeModalities.vision || editedModalities.vision,
audio: beforeModalities.audio || editedModalities.audio
};
}
const { handleModelChange } = useModelChangeValidation({
getRequiredModalities,
onValidationFailure: async (previousModelId) => {
if (previousModelId) {
await modelsStore.selectModelById(previousModelId);
}
}
});
function handleFileInputChange(event: Event) {
const input = event.target as HTMLInputElement;
if (!input.files || input.files.length === 0) return;
const files = Array.from(input.files);
processNewFiles(files);
input.value = '';
}
function handleGlobalKeydown(event: KeyboardEvent) {
if (event.key === 'Escape') {
event.preventDefault();
attemptCancel();
}
}
function attemptCancel() {
if (hasUnsavedChanges) {
showDiscardDialog = true;
} else {
onCancelEdit();
}
}
function handleRemoveExistingAttachment(index: number) {
if (!onEditedExtrasChange) return;
const newExtras = [...editedExtras];
newExtras.splice(index, 1);
onEditedExtrasChange(newExtras);
}
function handleRemoveUploadedFile(fileId: string) {
if (!onEditedUploadedFilesChange) return;
const newFiles = editedUploadedFiles.filter((f) => f.id !== fileId);
onEditedUploadedFilesChange(newFiles);
}
function handleSubmit() {
if (!canSubmit) return;
if (saveWithoutRegenerate && onSaveEditOnly) {
onSaveEditOnly();
} else {
onSaveEdit();
}
saveWithoutRegenerate = false;
}
async function processNewFiles(files: File[]) {
if (!onEditedUploadedFilesChange) return;
const { processFilesToChatUploaded } = await import('$lib/utils/browser-only');
const processed = await processFilesToChatUploaded(files);
onEditedUploadedFilesChange([...editedUploadedFiles, ...processed]);
}
function handlePaste(event: ClipboardEvent) {
if (!event.clipboardData) return;
const files = Array.from(event.clipboardData.items)
.filter((item) => item.kind === 'file')
.map((item) => item.getAsFile())
.filter((file): file is File => file !== null);
if (files.length > 0) {
event.preventDefault();
processNewFiles(files);
return;
}
const text = event.clipboardData.getData(MimeTypeText.PLAIN);
if (text.startsWith('"')) {
const parsed = parseClipboardContent(text);
if (parsed.textAttachments.length > 0) {
event.preventDefault();
onEditedContentChange(parsed.message);
const attachmentFiles = parsed.textAttachments.map(
(att) =>
new File([att.content], att.name, {
type: MimeTypeText.PLAIN
})
);
processNewFiles(attachmentFiles);
setTimeout(() => {
textareaElement?.focus();
}, 10);
return;
}
}
if (
text.length > 0 &&
pasteLongTextToFileLength > 0 &&
text.length > pasteLongTextToFileLength
) {
event.preventDefault();
const textFile = new File([text], 'Pasted', {
type: MimeTypeText.PLAIN
});
processNewFiles([textFile]);
}
}
$effect(() => {
if (textareaElement) {
autoResizeTextarea(textareaElement);
}
});
$effect(() => {
setEditModeActive(processNewFiles);
return () => {
clearEditMode();
};
});
</script>
<svelte:window onkeydown={handleGlobalKeydown} />
<input
bind:this={fileInputElement}
type="file"
multiple
class="hidden"
onchange={handleFileInputChange}
/>
<div
class="{INPUT_CLASSES} w-full max-w-[80%] overflow-hidden rounded-3xl backdrop-blur-md"
data-slot="edit-form"
>
<ChatAttachmentsList
attachments={editedExtras}
uploadedFiles={editedUploadedFiles}
readonly={false}
onFileRemove={(fileId) => {
if (fileId.startsWith('attachment-')) {
const index = parseInt(fileId.replace('attachment-', ''), 10);
if (!isNaN(index) && index >= 0 && index < editedExtras.length) {
handleRemoveExistingAttachment(index);
}
} else {
handleRemoveUploadedFile(fileId);
}
}}
limitToSingleRow
class="py-5"
style="scroll-padding: 1rem;"
/>
<div class="relative min-h-[48px] px-5 py-3">
<textarea
bind:this={textareaElement}
bind:value={editedContent}
class="field-sizing-content max-h-80 min-h-10 w-full resize-none bg-transparent text-sm outline-none"
onkeydown={onEditKeydown}
oninput={(e) => {
autoResizeTextarea(e.currentTarget);
onEditedContentChange(e.currentTarget.value);
}}
onpaste={handlePaste}
placeholder="Edit your message..."
></textarea>
<div class="flex w-full items-center gap-3" style="container-type: inline-size">
<Button
class="h-8 w-8 shrink-0 rounded-full bg-transparent p-0 text-muted-foreground hover:bg-foreground/10 hover:text-foreground"
onclick={() => fileInputElement?.click()}
type="button"
title="Add attachment"
>
<span class="sr-only">Attach files</span>
<Paperclip class="h-4 w-4" />
</Button>
<div class="flex-1"></div>
{#if isRouter}
<ModelsSelector
forceForegroundText={true}
useGlobalSelection={true}
onModelChange={handleModelChange}
/>
{/if}
<Button
class="h-8 w-8 shrink-0 rounded-full p-0"
onclick={handleSubmit}
disabled={!canSubmit}
type="button"
title={saveWithoutRegenerate ? 'Save changes' : 'Send and regenerate'}
>
<span class="sr-only">{saveWithoutRegenerate ? 'Save' : 'Send'}</span>
<ArrowUp class="h-5 w-5" />
</Button>
</div>
</div>
</div>
<div class="mt-2 flex w-full max-w-[80%] items-center justify-between">
{#if showSaveOnlyOption && onSaveEditOnly}
<div class="flex items-center gap-2">
<Switch id="save-only-switch" bind:checked={saveWithoutRegenerate} class="scale-75" />
<label for="save-only-switch" class="cursor-pointer text-xs text-muted-foreground">
Update without re-sending
</label>
</div>
{:else}
<div></div>
{/if}
<Button class="h-7 px-3 text-xs" onclick={attemptCancel} size="sm" variant="ghost">
<X class="mr-1 h-3 w-3" />
Cancel
</Button>
</div>
<DialogConfirmation
bind:open={showDiscardDialog}
title="Discard changes?"
description="You have unsaved changes. Are you sure you want to discard them?"
confirmText="Discard"
cancelText="Keep editing"
variant="destructive"
icon={AlertTriangle}
onConfirm={onCancelEdit}
onCancel={() => (showDiscardDialog = false)}
/>

View File

@ -1,18 +1,17 @@
<script lang="ts"> <script lang="ts">
import { Check, X, Send } from '@lucide/svelte';
import { Card } from '$lib/components/ui/card'; import { Card } from '$lib/components/ui/card';
import { Button } from '$lib/components/ui/button';
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app'; import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { config } from '$lib/stores/settings.svelte'; import { config } from '$lib/stores/settings.svelte';
import { autoResizeTextarea } from '$lib/utils';
import ChatMessageActions from './ChatMessageActions.svelte'; import ChatMessageActions from './ChatMessageActions.svelte';
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
interface Props { interface Props {
class?: string; class?: string;
message: DatabaseMessage; message: DatabaseMessage;
isEditing: boolean; isEditing: boolean;
editedContent: string; editedContent: string;
editedExtras?: DatabaseMessageExtra[];
editedUploadedFiles?: ChatUploadedFile[];
siblingInfo?: ChatMessageSiblingInfo | null; siblingInfo?: ChatMessageSiblingInfo | null;
showDeleteDialog: boolean; showDeleteDialog: boolean;
deletionInfo: { deletionInfo: {
@ -26,6 +25,8 @@
onSaveEditOnly?: () => void; onSaveEditOnly?: () => void;
onEditKeydown: (event: KeyboardEvent) => void; onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void; onEditedContentChange: (content: string) => void;
onEditedExtrasChange?: (extras: DatabaseMessageExtra[]) => void;
onEditedUploadedFilesChange?: (files: ChatUploadedFile[]) => void;
onCopy: () => void; onCopy: () => void;
onEdit: () => void; onEdit: () => void;
onDelete: () => void; onDelete: () => void;
@ -40,6 +41,8 @@
message, message,
isEditing, isEditing,
editedContent, editedContent,
editedExtras = [],
editedUploadedFiles = [],
siblingInfo = null, siblingInfo = null,
showDeleteDialog, showDeleteDialog,
deletionInfo, deletionInfo,
@ -48,6 +51,8 @@
onSaveEditOnly, onSaveEditOnly,
onEditKeydown, onEditKeydown,
onEditedContentChange, onEditedContentChange,
onEditedExtrasChange,
onEditedUploadedFilesChange,
onCopy, onCopy,
onEdit, onEdit,
onDelete, onDelete,
@ -61,12 +66,6 @@
let messageElement: HTMLElement | undefined = $state(); let messageElement: HTMLElement | undefined = $state();
const currentConfig = config(); const currentConfig = config();
$effect(() => {
if (isEditing && textareaElement) {
autoResizeTextarea(textareaElement);
}
});
$effect(() => { $effect(() => {
if (!messageElement || !message.content.trim()) return; if (!messageElement || !message.content.trim()) return;
@ -98,44 +97,23 @@
role="group" role="group"
> >
{#if isEditing} {#if isEditing}
<div class="w-full max-w-[80%]"> <ChatMessageEditForm
<textarea bind:textareaElement
bind:this={textareaElement} messageId={message.id}
bind:value={editedContent} {editedContent}
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}" {editedExtras}
onkeydown={onEditKeydown} {editedUploadedFiles}
oninput={(e) => { originalContent={message.content}
autoResizeTextarea(e.currentTarget); originalExtras={message.extra}
onEditedContentChange(e.currentTarget.value); showSaveOnlyOption={!!onSaveEditOnly}
}} {onCancelEdit}
placeholder="Edit your message..." {onSaveEdit}
></textarea> {onSaveEditOnly}
{onEditKeydown}
<div class="mt-2 flex justify-end gap-2"> {onEditedContentChange}
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="ghost"> {onEditedExtrasChange}
<X class="mr-1 h-3 w-3" /> {onEditedUploadedFilesChange}
Cancel />
</Button>
{#if onSaveEditOnly}
<Button
class="h-8 px-3"
onclick={onSaveEditOnly}
disabled={!editedContent.trim()}
size="sm"
variant="outline"
>
<Check class="mr-1 h-3 w-3" />
Save
</Button>
{/if}
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
<Send class="mr-1 h-3 w-3" />
Send
</Button>
</div>
</div>
{:else} {:else}
{#if message.extra && message.extra.length > 0} {#if message.extra && message.extra.length > 0}
<div class="mb-2 max-w-[80%]"> <div class="mb-2 max-w-[80%]">

View File

@ -66,10 +66,14 @@
await conversationsStore.navigateToSibling(siblingId); await conversationsStore.navigateToSibling(siblingId);
} }
async function handleEditWithBranching(message: DatabaseMessage, newContent: string) { async function handleEditWithBranching(
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) {
onUserAction?.(); onUserAction?.();
await chatStore.editMessageWithBranching(message.id, newContent); await chatStore.editMessageWithBranching(message.id, newContent, newExtras);
refreshAllMessages(); refreshAllMessages();
} }
@ -104,11 +108,12 @@
async function handleEditUserMessagePreserveResponses( async function handleEditUserMessagePreserveResponses(
message: DatabaseMessage, message: DatabaseMessage,
newContent: string newContent: string,
newExtras?: DatabaseMessageExtra[]
) { ) {
onUserAction?.(); onUserAction?.();
await chatStore.editUserMessagePreserveResponses(message.id, newContent); await chatStore.editUserMessagePreserveResponses(message.id, newContent, newExtras);
refreshAllMessages(); refreshAllMessages();
} }

View File

@ -17,7 +17,13 @@
AUTO_SCROLL_INTERVAL, AUTO_SCROLL_INTERVAL,
INITIAL_SCROLL_DELAY INITIAL_SCROLL_DELAY
} from '$lib/constants/auto-scroll'; } from '$lib/constants/auto-scroll';
import { chatStore, errorDialog, isLoading } from '$lib/stores/chat.svelte'; import {
chatStore,
errorDialog,
isLoading,
isEditing,
getAddFilesHandler
} from '$lib/stores/chat.svelte';
import { import {
conversationsStore, conversationsStore,
activeMessages, activeMessages,
@ -181,7 +187,18 @@
dragCounter = 0; dragCounter = 0;
if (event.dataTransfer?.files) { if (event.dataTransfer?.files) {
processFiles(Array.from(event.dataTransfer.files)); const files = Array.from(event.dataTransfer.files);
if (isEditing()) {
const handler = getAddFilesHandler();
if (handler) {
handler(files);
return;
}
}
processFiles(files);
} }
} }
@ -410,7 +427,7 @@
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl pb-4"> <div class="conversation-chat-form pointer-events-auto rounded-t-3xl pb-4">
<ChatForm <ChatForm
disabled={hasPropsError} disabled={hasPropsError || isEditing()}
isLoading={isCurrentConversationLoading} isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove} onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload} onFileUpload={handleFileUpload}

View File

@ -0,0 +1,7 @@
import Root from './switch.svelte';
export {
Root,
//
Root as Switch
};

View File

@ -0,0 +1,29 @@
<script lang="ts">
import { Switch as SwitchPrimitive } from 'bits-ui';
import { cn, type WithoutChildrenOrChild } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
class: className,
checked = $bindable(false),
...restProps
}: WithoutChildrenOrChild<SwitchPrimitive.RootProps> = $props();
</script>
<SwitchPrimitive.Root
bind:ref
bind:checked
data-slot="switch"
class={cn(
'peer inline-flex h-[1.15rem] w-8 shrink-0 items-center rounded-full border border-transparent shadow-xs transition-all outline-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input dark:data-[state=unchecked]:bg-input/80',
className
)}
{...restProps}
>
<SwitchPrimitive.Thumb
data-slot="switch-thumb"
class={cn(
'pointer-events-none block size-4 rounded-full bg-background ring-0 transition-transform data-[state=checked]:translate-x-[calc(100%-2px)] data-[state=unchecked]:translate-x-0 dark:data-[state=checked]:bg-primary-foreground dark:data-[state=unchecked]:bg-foreground'
)}
/>
</SwitchPrimitive.Root>

View File

@ -74,6 +74,8 @@ class ChatStore {
private processingStates = new SvelteMap<string, ApiProcessingState | null>(); private processingStates = new SvelteMap<string, ApiProcessingState | null>();
private activeConversationId = $state<string | null>(null); private activeConversationId = $state<string | null>(null);
private isStreamingActive = $state(false); private isStreamingActive = $state(false);
private isEditModeActive = $state(false);
private addFilesHandler: ((files: File[]) => void) | null = $state(null);
// ───────────────────────────────────────────────────────────────────────────── // ─────────────────────────────────────────────────────────────────────────────
// Loading State // Loading State
@ -965,230 +967,9 @@ class ChatStore {
// Editing // Editing
// ───────────────────────────────────────────────────────────────────────────── // ─────────────────────────────────────────────────────────────────────────────
async editAssistantMessage( clearEditMode(): void {
messageId: string, this.isEditModeActive = false;
newContent: string, this.addFilesHandler = null;
shouldBranch: boolean
): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
const result = this.getMessageByIdWithRole(messageId, 'assistant');
if (!result) return;
const { message: msg, index: idx } = result;
try {
if (shouldBranch) {
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
thinking: msg.thinking || '',
toolCalls: msg.toolCalls || '',
children: [],
model: msg.model
},
msg.parent!
);
await conversationsStore.updateCurrentNode(newMessage.id);
} else {
await DatabaseService.updateMessage(msg.id, { content: newContent, timestamp: Date.now() });
await conversationsStore.updateCurrentNode(msg.id);
conversationsStore.updateMessageAtIndex(idx, {
content: newContent,
timestamp: Date.now()
});
}
conversationsStore.updateConversationTimestamp();
await conversationsStore.refreshActiveMessages();
} catch (error) {
console.error('Failed to edit assistant message:', error);
}
}
async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv) return;
const result = this.getMessageByIdWithRole(messageId, 'user');
if (!result) return;
const { message: msg, index: idx } = result;
try {
await DatabaseService.updateMessage(messageId, {
content: newContent,
timestamp: Date.now()
});
conversationsStore.updateMessageAtIndex(idx, { content: newContent, timestamp: Date.now() });
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim(),
conversationsStore.titleUpdateConfirmationCallback
);
}
conversationsStore.updateConversationTimestamp();
} catch (error) {
console.error('Failed to edit user message:', error);
}
}
async editMessageWithBranching(messageId: string, newContent: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
let result = this.getMessageByIdWithRole(messageId, 'user');
if (!result) {
result = this.getMessageByIdWithRole(messageId, 'system');
}
if (!result) return;
const { message: msg } = result;
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const isFirstUserMessage =
msg.role === 'user' && rootMessage && msg.parent === rootMessage.id;
const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
thinking: msg.thinking || '',
toolCalls: msg.toolCalls || '',
children: [],
extra: msg.extra ? JSON.parse(JSON.stringify(msg.extra)) : undefined,
model: msg.model
},
parentId
);
await conversationsStore.updateCurrentNode(newMessage.id);
conversationsStore.updateConversationTimestamp();
if (isFirstUserMessage && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim(),
conversationsStore.titleUpdateConfirmationCallback
);
}
await conversationsStore.refreshActiveMessages();
if (msg.role === 'user') {
await this.generateResponseForMessage(newMessage.id);
}
} catch (error) {
console.error('Failed to edit message with branching:', error);
}
}
async regenerateMessageWithBranching(messageId: string, modelOverride?: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
try {
const idx = conversationsStore.findMessageIndex(messageId);
if (idx === -1) return;
const msg = conversationsStore.activeMessages[idx];
if (msg.role !== 'assistant') return;
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const parentMessage = allMessages.find((m) => m.id === msg.parent);
if (!parentMessage) return;
this.setChatLoading(activeConv.id, true);
this.clearChatStreaming(activeConv.id);
const newAssistantMessage = await DatabaseService.createMessageBranch(
{
convId: activeConv.id,
type: 'text',
timestamp: Date.now(),
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
parentMessage.id
);
await conversationsStore.updateCurrentNode(newAssistantMessage.id);
conversationsStore.updateConversationTimestamp();
await conversationsStore.refreshActiveMessages();
const conversationPath = filterByLeafNodeId(
allMessages,
parentMessage.id,
false
) as DatabaseMessage[];
// Use modelOverride if provided, otherwise use the original message's model
// If neither is available, don't pass model (will use global selection)
const modelToUse = modelOverride || msg.model || undefined;
await this.streamChatCompletion(
conversationPath,
newAssistantMessage,
undefined,
undefined,
modelToUse
);
} catch (error) {
if (!this.isAbortError(error))
console.error('Failed to regenerate message with branching:', error);
this.setChatLoading(activeConv?.id || '', false);
}
}
private async generateResponseForMessage(userMessageId: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv) return;
this.errorDialogState = null;
this.setChatLoading(activeConv.id, true);
this.clearChatStreaming(activeConv.id);
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const conversationPath = filterByLeafNodeId(
allMessages,
userMessageId,
false
) as DatabaseMessage[];
const assistantMessage = await DatabaseService.createMessageBranch(
{
convId: activeConv.id,
type: 'text',
timestamp: Date.now(),
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
userMessageId
);
conversationsStore.addMessageToActive(assistantMessage);
await this.streamChatCompletion(conversationPath, assistantMessage);
} catch (error) {
console.error('Failed to generate response:', error);
this.setChatLoading(activeConv.id, false);
}
} }
async continueAssistantMessage(messageId: string): Promise<void> { async continueAssistantMessage(messageId: string): Promise<void> {
@ -1340,19 +1121,284 @@ class ChatStore {
} }
} }
public isChatLoadingPublic(convId: string): boolean { async editAssistantMessage(
return this.isChatLoading(convId); messageId: string,
newContent: string,
shouldBranch: boolean
): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
const result = this.getMessageByIdWithRole(messageId, 'assistant');
if (!result) return;
const { message: msg, index: idx } = result;
try {
if (shouldBranch) {
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
thinking: msg.thinking || '',
toolCalls: msg.toolCalls || '',
children: [],
model: msg.model
},
msg.parent!
);
await conversationsStore.updateCurrentNode(newMessage.id);
} else {
await DatabaseService.updateMessage(msg.id, { content: newContent });
await conversationsStore.updateCurrentNode(msg.id);
conversationsStore.updateMessageAtIndex(idx, {
content: newContent
});
}
conversationsStore.updateConversationTimestamp();
await conversationsStore.refreshActiveMessages();
} catch (error) {
console.error('Failed to edit assistant message:', error);
}
} }
async editUserMessagePreserveResponses(
messageId: string,
newContent: string,
newExtras?: DatabaseMessageExtra[]
): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv) return;
const result = this.getMessageByIdWithRole(messageId, 'user');
if (!result) return;
const { message: msg, index: idx } = result;
try {
const updateData: Partial<DatabaseMessage> = {
content: newContent
};
// Update extras if provided (including empty array to clear attachments)
// Deep clone to avoid Proxy objects from Svelte reactivity
if (newExtras !== undefined) {
updateData.extra = JSON.parse(JSON.stringify(newExtras));
}
await DatabaseService.updateMessage(messageId, updateData);
conversationsStore.updateMessageAtIndex(idx, updateData);
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim(),
conversationsStore.titleUpdateConfirmationCallback
);
}
conversationsStore.updateConversationTimestamp();
} catch (error) {
console.error('Failed to edit user message:', error);
}
}
async editMessageWithBranching(
messageId: string,
newContent: string,
newExtras?: DatabaseMessageExtra[]
): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
let result = this.getMessageByIdWithRole(messageId, 'user');
if (!result) {
result = this.getMessageByIdWithRole(messageId, 'system');
}
if (!result) return;
const { message: msg } = result;
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const isFirstUserMessage =
msg.role === 'user' && rootMessage && msg.parent === rootMessage.id;
const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;
// Use newExtras if provided, otherwise copy existing extras
// Deep clone to avoid Proxy objects from Svelte reactivity
const extrasToUse =
newExtras !== undefined
? JSON.parse(JSON.stringify(newExtras))
: msg.extra
? JSON.parse(JSON.stringify(msg.extra))
: undefined;
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
thinking: msg.thinking || '',
toolCalls: msg.toolCalls || '',
children: [],
extra: extrasToUse,
model: msg.model
},
parentId
);
await conversationsStore.updateCurrentNode(newMessage.id);
conversationsStore.updateConversationTimestamp();
if (isFirstUserMessage && newContent.trim()) {
await conversationsStore.updateConversationTitleWithConfirmation(
activeConv.id,
newContent.trim(),
conversationsStore.titleUpdateConfirmationCallback
);
}
await conversationsStore.refreshActiveMessages();
if (msg.role === 'user') {
await this.generateResponseForMessage(newMessage.id);
}
} catch (error) {
console.error('Failed to edit message with branching:', error);
}
}
async regenerateMessageWithBranching(messageId: string, modelOverride?: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isLoading) return;
try {
const idx = conversationsStore.findMessageIndex(messageId);
if (idx === -1) return;
const msg = conversationsStore.activeMessages[idx];
if (msg.role !== 'assistant') return;
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const parentMessage = allMessages.find((m) => m.id === msg.parent);
if (!parentMessage) return;
this.setChatLoading(activeConv.id, true);
this.clearChatStreaming(activeConv.id);
const newAssistantMessage = await DatabaseService.createMessageBranch(
{
convId: activeConv.id,
type: 'text',
timestamp: Date.now(),
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
parentMessage.id
);
await conversationsStore.updateCurrentNode(newAssistantMessage.id);
conversationsStore.updateConversationTimestamp();
await conversationsStore.refreshActiveMessages();
const conversationPath = filterByLeafNodeId(
allMessages,
parentMessage.id,
false
) as DatabaseMessage[];
// Use modelOverride if provided, otherwise use the original message's model
// If neither is available, don't pass model (will use global selection)
const modelToUse = modelOverride || msg.model || undefined;
await this.streamChatCompletion(
conversationPath,
newAssistantMessage,
undefined,
undefined,
modelToUse
);
} catch (error) {
if (!this.isAbortError(error))
console.error('Failed to regenerate message with branching:', error);
this.setChatLoading(activeConv?.id || '', false);
}
}
private async generateResponseForMessage(userMessageId: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv) return;
this.errorDialogState = null;
this.setChatLoading(activeConv.id, true);
this.clearChatStreaming(activeConv.id);
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const conversationPath = filterByLeafNodeId(
allMessages,
userMessageId,
false
) as DatabaseMessage[];
const assistantMessage = await DatabaseService.createMessageBranch(
{
convId: activeConv.id,
type: 'text',
timestamp: Date.now(),
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
userMessageId
);
conversationsStore.addMessageToActive(assistantMessage);
await this.streamChatCompletion(conversationPath, assistantMessage);
} catch (error) {
console.error('Failed to generate response:', error);
this.setChatLoading(activeConv.id, false);
}
}
getAddFilesHandler(): ((files: File[]) => void) | null {
return this.addFilesHandler;
}
public getAllLoadingChats(): string[] {
return Array.from(this.chatLoadingStates.keys());
}
public getAllStreamingChats(): string[] {
return Array.from(this.chatStreamingStates.keys());
}
public getChatStreamingPublic( public getChatStreamingPublic(
convId: string convId: string
): { response: string; messageId: string } | undefined { ): { response: string; messageId: string } | undefined {
return this.getChatStreaming(convId); return this.getChatStreaming(convId);
} }
public getAllLoadingChats(): string[] {
return Array.from(this.chatLoadingStates.keys()); public isChatLoadingPublic(convId: string): boolean {
return this.isChatLoading(convId);
} }
public getAllStreamingChats(): string[] {
return Array.from(this.chatStreamingStates.keys()); isEditing(): boolean {
return this.isEditModeActive;
}
setEditModeActive(handler: (files: File[]) => void): void {
this.isEditModeActive = true;
this.addFilesHandler = handler;
} }
// ───────────────────────────────────────────────────────────────────────────── // ─────────────────────────────────────────────────────────────────────────────
@ -1416,13 +1462,17 @@ class ChatStore {
export const chatStore = new ChatStore(); export const chatStore = new ChatStore();
export const isLoading = () => chatStore.isLoading; export const activeProcessingState = () => chatStore.activeProcessingState;
export const clearEditMode = () => chatStore.clearEditMode();
export const currentResponse = () => chatStore.currentResponse; export const currentResponse = () => chatStore.currentResponse;
export const errorDialog = () => chatStore.errorDialogState; export const errorDialog = () => chatStore.errorDialogState;
export const activeProcessingState = () => chatStore.activeProcessingState; export const getAddFilesHandler = () => chatStore.getAddFilesHandler();
export const isChatStreaming = () => chatStore.isStreaming();
export const isChatLoading = (convId: string) => chatStore.isChatLoadingPublic(convId);
export const getChatStreaming = (convId: string) => chatStore.getChatStreamingPublic(convId);
export const getAllLoadingChats = () => chatStore.getAllLoadingChats(); export const getAllLoadingChats = () => chatStore.getAllLoadingChats();
export const getAllStreamingChats = () => chatStore.getAllStreamingChats(); export const getAllStreamingChats = () => chatStore.getAllStreamingChats();
export const getChatStreaming = (convId: string) => chatStore.getChatStreamingPublic(convId);
export const isChatLoading = (convId: string) => chatStore.isChatLoadingPublic(convId);
export const isChatStreaming = () => chatStore.isStreaming();
export const isEditing = () => chatStore.isEditing();
export const isLoading = () => chatStore.isLoading;
export const setEditModeActive = (handler: (files: File[]) => void) =>
chatStore.setEditModeActive(handler);