diff --git a/.github/workflows/server-webui.yml b/.github/workflows/server-webui.yml new file mode 100644 index 0000000000..f8a261eefa --- /dev/null +++ b/.github/workflows/server-webui.yml @@ -0,0 +1,295 @@ +# Server WebUI build and tests +name: Server WebUI + +on: + workflow_dispatch: # allows manual triggering + inputs: + sha: + description: 'Commit SHA1 to build' + required: false + type: string + slow_tests: + description: 'Run slow tests' + required: true + type: boolean + push: + branches: + - master + paths: ['.github/workflows/server-webui.yml', 'tools/server/webui/**.*', 'tools/server/tests/**.*', 'tools/server/public/**'] + pull_request: + types: [opened, synchronize, reopened] + paths: ['.github/workflows/server-webui.yml', 'tools/server/webui/**.*', 'tools/server/tests/**.*', 'tools/server/public/**'] + +env: + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + LLAMA_LOG_VERBOSITY: 10 + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + webui-setup: + name: WebUI Setup + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "22" + cache: "npm" + cache-dependency-path: "tools/server/webui/package-lock.json" + + - name: Cache node_modules + uses: actions/cache@v4 + id: cache-node-modules + with: + path: tools/server/webui/node_modules + key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }} + restore-keys: | + ${{ runner.os }}-node-modules- + + - name: Install dependencies + if: steps.cache-node-modules.outputs.cache-hit != 'true' + run: npm ci + working-directory: tools/server/webui + + webui-check: + needs: webui-setup + name: WebUI Check + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "22" + + - name: Restore node_modules cache + uses: actions/cache@v4 + with: + path: tools/server/webui/node_modules + key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }} + restore-keys: | + ${{ runner.os }}-node-modules- + + - name: Run type checking + run: npm run check + working-directory: tools/server/webui + + - name: Run linting + run: npm run lint + working-directory: tools/server/webui + + webui-build: + needs: webui-check + name: WebUI Build + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "22" + + - name: Restore node_modules cache + uses: actions/cache@v4 + with: + path: tools/server/webui/node_modules + key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }} + restore-keys: | + ${{ runner.os }}-node-modules- + + - name: Build application + run: npm run build + working-directory: tools/server/webui + + webui-tests: + needs: webui-build + name: Run WebUI tests + permissions: + contents: read + + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "22" + + - name: Restore node_modules cache + uses: actions/cache@v4 + with: + path: tools/server/webui/node_modules + key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }} + restore-keys: | + ${{ runner.os }}-node-modules- + + - name: Install Playwright browsers + run: npx playwright install --with-deps + working-directory: tools/server/webui + + - name: Build Storybook + run: npm run build-storybook + working-directory: tools/server/webui + + - name: Run Client tests + run: npm run test:client + working-directory: tools/server/webui + + - name: Run Server tests + run: npm run test:server + working-directory: tools/server/webui + + - name: Run UI tests + run: npm run test:ui -- --testTimeout=60000 + working-directory: tools/server/webui + + - name: Run E2E tests + run: npm run test:e2e + working-directory: tools/server/webui + + server-build: + needs: [webui-tests] + runs-on: ubuntu-latest + + strategy: + matrix: + sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken + build_type: [RelWithDebInfo] + include: + - build_type: Release + sanitizer: "" + fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken + + steps: + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get -y install \ + build-essential \ + xxd \ + git \ + cmake \ + curl \ + wget \ + language-pack-en \ + libssl-dev + + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Python setup + id: setup_python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Tests dependencies + id: test_dependencies + run: | + pip install -r tools/server/tests/requirements.txt + + - name: Setup Node.js for WebUI + uses: actions/setup-node@v4 + with: + node-version: "22" + cache: "npm" + cache-dependency-path: "tools/server/webui/package-lock.json" + + - name: Install WebUI dependencies + run: npm ci + working-directory: tools/server/webui + + - name: Build WebUI + run: npm run build + working-directory: tools/server/webui + + - name: Build (no OpenMP) + id: cmake_build_no_openmp + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake -B build \ + -DGGML_NATIVE=OFF \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DLLAMA_BUILD_SERVER=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF ; + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server + + - name: Build (sanitizers) + id: cmake_build_sanitizers + if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }} + run: | + cmake -B build \ + -DGGML_NATIVE=OFF \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DLLAMA_BUILD_SERVER=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server + + - name: Build (sanitizers) + id: cmake_build + if: ${{ matrix.sanitizer == '' }} + run: | + cmake -B build \ + -DGGML_NATIVE=OFF \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DLLAMA_BUILD_SERVER=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ; + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server + + - name: Tests + id: server_integration_tests + if: ${{ matrix.sanitizer == '' }} + env: + GITHUB_ACTIONS: "true" + run: | + cd tools/server/tests + ./tests.sh + + - name: Tests (sanitizers) + id: server_integration_tests_sanitizers + if: ${{ matrix.sanitizer != '' }} + run: | + cd tools/server/tests + LLAMA_SANITIZE=1 ./tests.sh + + - name: Slow tests + id: server_integration_tests_slow + if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} + run: | + cd tools/server/tests + SLOW_TESTS=1 ./tests.sh diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index a57d0e8b1c..f9e2a79af7 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -76,270 +76,6 @@ jobs: run: | pip install -r tools/server/tests/requirements.txt - webui-setup: - name: WebUI Setup - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: "22" - cache: "npm" - cache-dependency-path: "tools/server/webui/package-lock.json" - - - name: Cache node_modules - uses: actions/cache@v4 - id: cache-node-modules - with: - path: tools/server/webui/node_modules - key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }} - restore-keys: | - ${{ runner.os }}-node-modules- - - - name: Install dependencies - if: steps.cache-node-modules.outputs.cache-hit != 'true' - run: npm ci - working-directory: tools/server/webui - - webui-check: - needs: webui-setup - name: WebUI Check - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: "22" - - - name: Restore node_modules cache - uses: actions/cache@v4 - with: - path: tools/server/webui/node_modules - key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }} - restore-keys: | - ${{ runner.os }}-node-modules- - - - name: Run type checking - run: npm run check - working-directory: tools/server/webui - - - name: Run linting - run: npm run lint - working-directory: tools/server/webui - - webui-build: - needs: webui-check - name: WebUI Build - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: "22" - - - name: Restore node_modules cache - uses: actions/cache@v4 - with: - path: tools/server/webui/node_modules - key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }} - restore-keys: | - ${{ runner.os }}-node-modules- - - - name: Build application - run: npm run build - working-directory: tools/server/webui - - webui-tests: - needs: webui-build - name: Run WebUI tests - permissions: - contents: read - - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: "22" - - - name: Restore node_modules cache - uses: actions/cache@v4 - with: - path: tools/server/webui/node_modules - key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }} - restore-keys: | - ${{ runner.os }}-node-modules- - - - name: Install Playwright browsers - run: npx playwright install --with-deps - working-directory: tools/server/webui - - - name: Build Storybook - run: npm run build-storybook - working-directory: tools/server/webui - - - name: Run Client tests - run: npm run test:client - working-directory: tools/server/webui - - - name: Run Server tests - run: npm run test:server - working-directory: tools/server/webui - - - name: Run UI tests - run: npm run test:ui -- --testTimeout=60000 - working-directory: tools/server/webui - - - name: Run E2E tests - run: npm run test:e2e - working-directory: tools/server/webui - - server-build: - needs: [webui-tests] - runs-on: ubuntu-latest - - strategy: - matrix: - sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken - build_type: [RelWithDebInfo] - include: - - build_type: Release - sanitizer: "" - fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken - - steps: - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get -y install \ - build-essential \ - xxd \ - git \ - cmake \ - curl \ - wget \ - language-pack-en \ - libssl-dev - - - name: Clone - id: checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - - name: Python setup - id: setup_python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt - - - name: Setup Node.js for WebUI - uses: actions/setup-node@v4 - with: - node-version: "22" - cache: "npm" - cache-dependency-path: "tools/server/webui/package-lock.json" - - - name: Install WebUI dependencies - run: npm ci - working-directory: tools/server/webui - - - name: Build WebUI - run: npm run build - working-directory: tools/server/webui - - - name: Build (no OpenMP) - id: cmake_build_no_openmp - if: ${{ matrix.sanitizer == 'THREAD' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_CURL=OFF \ - -DLLAMA_OPENSSL=ON \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ - -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Build (sanitizers) - id: cmake_build_sanitizers - if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_CURL=OFF \ - -DLLAMA_OPENSSL=ON \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ - -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Build (sanitizers) - id: cmake_build - if: ${{ matrix.sanitizer == '' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_CURL=OFF \ - -DLLAMA_OPENSSL=ON \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Tests - id: server_integration_tests - if: ${{ matrix.sanitizer == '' }} - env: - GITHUB_ACTIONS: "true" - run: | - cd tools/server/tests - ./tests.sh - - - name: Tests (sanitizers) - id: server_integration_tests_sanitizers - if: ${{ matrix.sanitizer != '' }} - run: | - cd tools/server/tests - LLAMA_SANITIZE=1 ./tests.sh - - - name: Slow tests - id: server_integration_tests_slow - if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} - run: | - cd tools/server/tests - SLOW_TESTS=1 ./tests.sh - - server-windows: runs-on: windows-2022 diff --git a/CODEOWNERS b/CODEOWNERS index 8e62a36e81..8a0c98c968 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -87,7 +87,8 @@ /tests/ @ggerganov /tests/test-chat-.* @pwilkin /tools/batched-bench/ @ggerganov -/tools/main/ @ggerganov +/tools/cli/ @ngxson +/tools/completion/ @ggerganov /tools/mtmd/ @ngxson /tools/perplexity/ @ggerganov /tools/quantize/ @ggerganov diff --git a/README.md b/README.md index b7d24c9dd7..5f2076d0a3 100644 --- a/README.md +++ b/README.md @@ -313,7 +313,7 @@ The Hugging Face platform provides a variety of online tools for converting, qua To learn more about model quantization, [read this documentation](tools/quantize/README.md) -## [`llama-cli`](tools/main) +## [`llama-cli`](tools/cli) #### A CLI tool for accessing and experimenting with most of `llama.cpp`'s functionality. @@ -525,7 +525,8 @@ To learn more about model quantization, [read this documentation](tools/quantize ## Other documentation -- [main (cli)](tools/main/README.md) +- [cli](tools/cli/README.md) +- [completion](tools/completion/README.md) - [server](tools/server/README.md) - [GBNF grammars](grammars/README.md) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9b4c63cf22..f09eab54ac 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -136,19 +136,11 @@ class ModelBase: self.remote_hf_model_id = remote_hf_model_id self.sentence_transformers_dense_modules = sentence_transformers_dense_modules self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams - self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {} self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id) self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters - if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: - if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None: - self.rope_parameters["rope_theta"] = rope_theta - if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None: - self.rope_parameters["rope_type"] = rope_type - # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. @@ -768,6 +760,15 @@ class TextModel(ModelBase): self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {} + + # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters + if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: + if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None: + self.rope_parameters["rope_theta"] = rope_theta + if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None: + self.rope_parameters["rope_type"] = rope_type + @classmethod def __init_subclass__(cls): # can't use an abstract property, because overriding it without type errors @@ -1206,6 +1207,9 @@ class TextModel(ModelBase): if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95": # ref: https://huggingface.co/MiniMaxAI/MiniMax-M2 res = "minimax-m2" + if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665": + # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer + res = "kormo" if res is None: logger.warning("\n") @@ -3401,7 +3405,7 @@ class QwenModel(TextModel): self._set_vocab_qwen() -@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration") +@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM") class Qwen2Model(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2 @@ -8576,8 +8580,18 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel): class NemotronHModel(GraniteHybridModel): """Hybrid mamba2/attention model from NVIDIA""" model_arch = gguf.MODEL_ARCH.NEMOTRON_H + is_moe: bool = False def __init__(self, *args, **kwargs): + # We have to determine the correct model architecture (MoE vs non-MoE) before + # calling the parent __init__. This is because the parent constructor + # uses self.model_arch to build the tensor name map, and all MoE-specific + # mappings would be missed if it were called with the default non-MoE arch. + hparams = ModelBase.load_hparams(args[0], self.is_mistral_format) + if "num_experts_per_tok" in hparams: + self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE + self.is_moe = True + super().__init__(*args, **kwargs) # Save the top-level head_dim for later @@ -8589,9 +8603,11 @@ class NemotronHModel(GraniteHybridModel): # Update the ssm / attn / mlp layers # M: Mamba2, *: Attention, -: MLP + # MoE: + # M: Mamba2, *: Attention, E: Expert hybrid_override_pattern = self.hparams["hybrid_override_pattern"] self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"] - self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"] + self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")] def get_attn_layers(self): hybrid_override_pattern = self.hparams["hybrid_override_pattern"] @@ -8607,10 +8623,28 @@ class NemotronHModel(GraniteHybridModel): # Set feed_forward_length # NOTE: This will trigger an override warning. This is preferrable to # duplicating all the parent logic - n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) - self.gguf_writer.add_feed_forward_length([ - n_ff if i in self._mlp_layers else 0 for i in range(self.block_count) - ]) + if not self.is_moe: + n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) + self.gguf_writer.add_feed_forward_length([ + n_ff if i in self._mlp_layers else 0 for i in range(self.block_count) + ]) + else: + moe_intermediate_size = self.hparams["moe_intermediate_size"] + self.gguf_writer.add_feed_forward_length([ + moe_intermediate_size if i in self._mlp_layers else 0 for i in range(self.block_count) + ]) + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["moe_shared_expert_intermediate_size"]) + self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"]) + self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_group_count(self.hparams["n_group"]) + + # number of experts used per token (top-k) + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) def set_vocab(self): super().set_vocab() @@ -8618,7 +8652,81 @@ class NemotronHModel(GraniteHybridModel): # The tokenizer _does_ add a BOS token (via post_processor type # TemplateProcessing) but does not set add_bos_token to true in the # config, so we need to explicitly override it here. - self.gguf_writer.add_add_bos_token(True) + if not self.is_moe: + self.gguf_writer.add_add_bos_token(True) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if self.is_moe and bid is not None: + if name.endswith("mixer.gate.e_score_correction_bias"): + new_name = name.replace("e_score_correction_bias", "e_score_correction.bias") + mapped_name = self.map_tensor_name(new_name) + return [(mapped_name, data_torch)] + + if name.endswith("mixer.dt_bias"): + new_name = name.replace("dt_bias", "dt.bias") + mapped_name = self.map_tensor_name(new_name) + return [(mapped_name, data_torch)] + + if name.endswith("mixer.conv1d.weight"): + squeezed_data = data_torch.squeeze() + mapped_name = self.map_tensor_name(name) + return [(mapped_name, squeezed_data)] + + if name.endswith("mixer.A_log"): + transformed_data = -torch.exp(data_torch) + reshaped_data = transformed_data.squeeze().reshape(-1, 1) + mapped_name = self.map_tensor_name(name) + return [(mapped_name, reshaped_data)] + + if name.endswith("mixer.D"): + reshaped_data = data_torch.squeeze().reshape(-1, 1) + mapped_name = self.map_tensor_name(name) + return [(mapped_name, reshaped_data)] + + if name.endswith("mixer.norm.weight"): + reshaped_data = data_torch.reshape(8, 512) + mapped_name = self.map_tensor_name(name) + return [(mapped_name, reshaped_data)] + + if name.find("mixer.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 2: + # merge the experts into a single tensor + tensors: list[tuple[str, Tensor]] = [] + for w_name in ["down_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"backbone.layers.{bid}.mixer.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + + return super().modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") @ModelBase.register("BailingMoeForCausalLM") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index b8f694e86c..5e8456a7ea 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -143,6 +143,7 @@ models = [ {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", }, {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, + {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", }, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 5989b873a6..9d1452e3f0 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -9,7 +9,8 @@ Adding a model requires few steps: After following these steps, you can open PR. Also, it is important to check that the examples and main ggml backends (CUDA, METAL, CPU) are working with the new architecture, especially: -- [main](/tools/main/) +- [cli](/tools/cli/) +- [completion](/tools/completion/) - [imatrix](/tools/imatrix/) - [quantize](/tools/quantize/) - [server](/tools/server/) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 72a82a8911..514f086f68 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1976,9 +1976,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s break; case GGML_TYPE_F16: - if (!opt_experimental) { - return false; - } break; default: diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index c99b6a0d18..346f0bd339 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -903,7 +903,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const float * restrict vy = (const float * restrict) y; for (uint32_t i = 0; i < n; i++) { - rsum += vx[i] * (__fp16) vy[i]; + rsum += (float)vx[i] * vy[i]; } *s = rsum; return; @@ -917,7 +917,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri // for some reason we need volatile here so that the compiler doesn't try anything funky volatile HVX_Vector rsum = Q6_V_vsplat_R(0); - + float r_sum_scalar = 0.0f; uint32_t i = 0; for (i = 0; i < nv0; i++) { @@ -926,31 +926,42 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri HVX_Vector x = vx[i]; HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); + //NOTE: need volatile here to prevent compiler optimization + // Seem compiler cannot guarantee read-after-write?? + volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); + volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo); rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); } if (nv1) { - HVX_VectorPair yp = vy[i]; + // HVX_VectorPair yp = vy[i]; - HVX_Vector x = vx[i]; - HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 + // HVX_Vector x = vx[i]; + // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - if (nv1 >= 32) { - HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); - nv1 -= 32; - } + // if (nv1 >= 32) { + // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); + // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); + // nv1 -= 32; + // } + // rsum = hvx_vec_qf32_reduce_sum(rsum); + + // if (nv1) { + // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); + // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); + // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + // } + + //process the remainder using scalar loop rsum = hvx_vec_qf32_reduce_sum(rsum); + const __fp16 * restrict sx = (const __fp16 * restrict) x; + const float * restrict sy = (const float * restrict) y; - if (nv1) { - HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + for (uint32_t i = nv0 * 64; i < n; i++) { + r_sum_scalar += (float) sx[i] * sy[i]; } // hvx_vec_dump_fp16("X", x); @@ -961,7 +972,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri rsum = hvx_vec_qf32_reduce_sum(rsum); } - *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)); + *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar; # ifdef HTP_DEBUG { @@ -1498,9 +1509,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = sizeof(__fp16) * ne00; - const size_t src1_row_size = sizeof(float) * ne10; - assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -1510,8 +1518,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0, // This is the size of the rest of the dimensions of the result const uint32_t nr1 = ne1 * ne2 * ne3; - uint32_t chunk_size = 64; - // distribute the thread work across the inner or outer loop based on which one is larger uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows @@ -1544,11 +1550,11 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0, const uint32_t blck_0 = 64; const uint32_t blck_1 = 64; - float tmp[32]; + __attribute__((aligned(128))) float tmp[64]; for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { - for (uint32_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1++) { + for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { const uint32_t i13 = (ir1 / (ne12 * ne1)); const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); @@ -1561,13 +1567,16 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0, const uint32_t i2 = i12; const uint32_t i3 = i13; - const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); + const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); const uint8_t * restrict src1_col = - (const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size; + (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - for (uint32_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0++) { - vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col); + const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); + for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { + // Use nb01 stride for non-contiguous src0 support + const uint8_t * restrict src0_row = src0_base + ir0 * nb01; + vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col); } hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 7b7d1c1233..f24270bb1c 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -769,9 +769,16 @@ ggml_metal_device_t ggml_metal_device_init(void) { #endif dev->props.use_shared_buffers = dev->props.has_unified_memory; +#if TARGET_OS_OSX + // In case of eGPU, shared memory may be preferable. + dev->props.use_shared_buffers |= [dev->mtl_device location] == MTLDeviceLocationExternal; +#endif if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { dev->props.use_shared_buffers = false; } + if (getenv("GGML_METAL_SHARED_BUFFERS_ENABLE") != NULL) { + dev->props.use_shared_buffers = true; + } dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2648fedaa7..aa005d49b2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -420,6 +420,7 @@ class MODEL_ARCH(IntEnum): JAIS = auto() NEMOTRON = auto() NEMOTRON_H = auto() + NEMOTRON_H_MOE = auto() EXAONE = auto() EXAONE4 = auto() GRANITE = auto() @@ -810,6 +811,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.JAIS: "jais", MODEL_ARCH.NEMOTRON: "nemotron", MODEL_ARCH.NEMOTRON_H: "nemotron_h", + MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe", MODEL_ARCH.EXAONE: "exaone", MODEL_ARCH.EXAONE4: "exaone4", MODEL_ARCH.GRANITE: "granite", @@ -2618,6 +2620,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.NEMOTRON_H_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + # experts + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + # shared expert + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.EXAONE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 35b5873e44..0ec28034a0 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -154,7 +154,8 @@ class TensorNameMap: "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada "layers.{bid}.input_layernorm", # qwen3-embedding - "model.layers.{bid}.attention_layernorm" # apertus + "model.layers.{bid}.attention_layernorm", # apertus + "model.layers.{bid}.pre_attention_layernorm", # kormo ), # Attention norm 2 @@ -342,6 +343,7 @@ class TensorNameMap: "model.transformer.blocks.{bid}.ff_norm", # llada "layers.{bid}.post_attention_layernorm", # qwen3-embedding "model.layers.{bid}.feedforward_layernorm", # apertus + "model.layers.{bid}.pre_mlp_layernorm", # kormo ), # Pre feed-forward norm @@ -377,6 +379,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.gate", # lfm2moe "model.layers.{bid}.mlp.router.gate", # afmoe "layers.{bid}.gate", # mistral-large + "backbone.layers.{bid}.mixer.gate", # nemotron-h-moe ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -390,6 +393,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.expert_bias", # afmoe "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2 + "backbone.layers.{bid}.mixer.gate.e_score_correction" # nemotron-h-moe ), # Feed-forward up @@ -438,7 +442,7 @@ class TensorNameMap: "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx - "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe + "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe, nemotron-h-moe (merged) "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe @@ -452,6 +456,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.down_proj", "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan "layers.{bid}.shared_experts.w3", # mistral-large + "backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe ), MODEL_TENSOR.FFN_UP_CHEXP: ( @@ -546,7 +551,7 @@ class TensorNameMap: "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx - "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe + "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe nemotron-h-moe (merged) "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 @@ -561,6 +566,7 @@ class TensorNameMap: "model.layers.{bid}.shared_mlp.output_linear", # granitemoe "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan "layers.{bid}.shared_experts.w2", # mistral-large + "backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe ), MODEL_TENSOR.FFN_DOWN_CHEXP: ( @@ -704,6 +710,7 @@ class TensorNameMap: "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 "model.layers.{bid}.linear_attn.dt_proj", # qwen3next + "backbone.layers.{bid}.mixer.dt", # nemotron-h-moe ), MODEL_TENSOR.SSM_DT_NORM: ( diff --git a/grammars/README.md b/grammars/README.md index 11e3b6dd90..daac7f4d8d 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -1,6 +1,6 @@ # GBNF Guide -GBNF (GGML BNF) is a format for defining [formal grammars](https://en.wikipedia.org/wiki/Formal_grammar) to constrain model outputs in `llama.cpp`. For example, you can use it to force the model to generate valid JSON, or speak only in emojis. GBNF grammars are supported in various ways in `tools/main` and `tools/server`. +GBNF (GGML BNF) is a format for defining [formal grammars](https://en.wikipedia.org/wiki/Formal_grammar) to constrain model outputs in `llama.cpp`. For example, you can use it to force the model to generate valid JSON, or speak only in emojis. GBNF grammars are supported in various ways in `tools/cli`, `tools/completion` and `tools/server`. ## Background @@ -135,7 +135,7 @@ While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) ma You can use GBNF grammars: - In [llama-server](../tools/server)'s completion endpoints, passed as the `grammar` body field -- In [llama-cli](../tools/main), passed as the `--grammar` & `--grammar-file` flags +- In [llama-cli](../tools/cli) and [llama-completion](../tools/completion), passed as the `--grammar` & `--grammar-file` flags - With [test-gbnf-validator](../tests/test-gbnf-validator.cpp), to test them against strings. ## JSON Schemas → GBNF @@ -145,7 +145,7 @@ You can use GBNF grammars: - In [llama-server](../tools/server): - For any completion endpoints, passed as the `json_schema` body field - For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}` or `{ type: "json_schema", json_schema: {"schema": ...} }`) -- In [llama-cli](../tools/main), passed as the `--json` / `-j` flag +- In [llama-cli](../tools/cli) and [llama-completion](../tools/completion), passed as the `--json` / `-j` flag - To convert to a grammar ahead of time: - in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py) - in JavaScript with [json-schema-to-grammar.mjs](../tools/server/public_legacy/json-schema-to-grammar.mjs) (this is used by the [server](../tools/server)'s Web UI) diff --git a/scripts/snapdragon/adb/run-mtmd.sh b/scripts/snapdragon/adb/run-mtmd.sh new file mode 100755 index 0000000000..91d868278a --- /dev/null +++ b/scripts/snapdragon/adb/run-mtmd.sh @@ -0,0 +1,65 @@ +#!/bin/sh +# + +# Basedir on device +basedir=/data/local/tmp/llama.cpp + +cli_opts= + +branch=. +[ "$B" != "" ] && branch=$B + +adbserial= +[ "$S" != "" ] && adbserial="-s $S" + +model="gemma-3-4b-it-Q4_0.gguf" +[ "$M" != "" ] && model="$M" + +mmproj="mmproj-F16.gguf" +[ "$MMPROJ" != "" ] && mmproj="$MMPROJ" + +image= +[ "$IMG" != "" ] && image="$IMG" + +device="HTP0" +[ "$D" != "" ] && device="$D" + +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" + +experimental="GGML_HEXAGON_EXPERIMENTAL=1" +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" + +sched= +[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" + +opmask= +[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" + +nhvx= +[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" + +ndev= +[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" + +# MTMD backend device for vision model (defaults to CPU if not set) +mtmd_backend= +[ "$MTMD_DEVICE" != "" ] && mtmd_backend="MTMD_BACKEND_DEVICE=$MTMD_DEVICE" + +set -x + +adb $adbserial shell " \ + cd $basedir; ulimit -c unlimited; \ + LD_LIBRARY_PATH=$basedir/$branch/lib \ + ADSP_LIBRARY_PATH=$basedir/$branch/lib \ + $verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \ + ./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \ + --mmproj $basedir/../gguf/$mmproj \ + --image $basedir/../gguf/$image \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \ + -ngl 99 --device $device -v $cli_opts $@ \ +" diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 1cb91209f5..e49ad04e8b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -76,6 +76,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, + { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, @@ -1800,6 +1801,39 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_NEMOTRON_H_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // dense FFN + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + // MoE FFN (for MoE layers) + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B,"blk.%d.exp_probs_b" }, + // MoE shared expert layer + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_EXAONE, { @@ -2854,6 +2888,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: return true; default: diff --git a/src/llama-arch.h b/src/llama-arch.h index f01e7c36b8..09b6c7ffcb 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -80,6 +80,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, + LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 4942be7bf6..8191d9f651 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -254,6 +254,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= s_copy->ne[0] == mctx->get_n_rs(); + + res &= s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; + + res &= head == mctx->get_head(); + res &= rs_z == mctx->get_rs_z(); + + return res; +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -461,8 +479,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - inp_attn->set_input(ubatch); - inp_rs->set_input(ubatch); + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; } // @@ -1089,6 +1145,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_relu(ctx0, cur); cb(cur, "ffn_moe_relu", il); } break; + case LLM_FFN_RELU_SQR: + if (gate_exps) { + // TODO: add support for gated squared relu + GGML_ABORT("fatal error: gated squared relu not implemented"); + } else { + cur = ggml_relu(ctx0, cur); + cur = ggml_sqr(ctx0, cur); + cb(cur, "ffn_moe_relu_sqr", il); + } break; default: GGML_ABORT("fatal error"); } @@ -1841,6 +1906,9 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + inp->head = mctx_cur->get_head(); + inp->rs_z = mctx_cur->get_rs_z(); + return inp; } @@ -1909,10 +1977,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); - auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); - auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } diff --git a/src/llama-graph.h b/src/llama-graph.h index e9d387bd7c..81ac329cc3 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -225,6 +225,8 @@ public: void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * s_copy; // I32 [n_rs] // views of s_copy, computed once per graph @@ -233,6 +235,10 @@ public: ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs] const llama_memory_recurrent_context * mctx; + + // used in view offsets, need to match for valid graph reuse + uint32_t head; + int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -365,22 +371,28 @@ public: class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( + const llama_cparams & cparams, std::unique_ptr inp_attn, - std::unique_ptr inp_rs, - const llama_memory_hybrid_context * mctx) : + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : inp_attn(std::move(inp_attn)), inp_rs(std::move(inp_rs)), + cparams(cparams), mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + std::unique_ptr inp_attn; std::unique_ptr inp_rs; llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + const llama_cparams cparams; + const llama_memory_hybrid_context * mctx; }; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 96c9598c24..83d6d6ee3c 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -2,6 +2,7 @@ #include "ggml.h" +#include #include void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 30386c157b..061d04ec2e 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1569,9 +1569,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id]; + slot_info sinfo; + bool res = true; - res = res && state_read_meta(io, strm, cell_count, seq_id); - res = res && state_read_data(io, strm, cell_count); + res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id); + res = res && state_read_data(io, strm, cell_count, sinfo); if (!res) { if (seq_id == -1) { @@ -1710,7 +1712,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t } } -bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) { auto & cells = v_cells[strm]; auto & head = v_heads[strm]; @@ -1747,7 +1749,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 ubatch.seq_id[i] = &dest_seq_id; } - const auto sinfo = find_slot(ubatch, true); + sinfo = find_slot(ubatch, false); if (sinfo.empty()) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; @@ -1757,20 +1759,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350 apply_ubatch(sinfo, ubatch); - const auto head_cur = sinfo.head(); + LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id); - // keep the head at the old position because we will read the KV data into it in state_read_data() - head = head_cur; - - LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id); - - // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head_cur + cell_count <= cells.size()); - GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]); - GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]); - GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); - GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); + // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values + GGML_ASSERT(sinfo.n_stream() == 1); + GGML_ASSERT(sinfo.idxs[0].size() == cell_count); + for (uint32_t i = 0; i < cell_count; ++i) { + const uint32_t idx = sinfo.idxs[0][i]; + GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]); + GGML_ASSERT(cells.seq_has(idx, dest_seq_id)); + } } else { // whole KV cache restore @@ -1803,15 +1801,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 } } + // Create contiguous slot_info for whole cache restore + sinfo.s0 = strm; + sinfo.s1 = strm; + sinfo.resize(1); + sinfo.strm[0] = strm; + sinfo.idxs[0].resize(cell_count); + for (uint32_t i = 0; i < cell_count; ++i) { + sinfo.idxs[0][i] = i; + } + head = 0; } return true; } -bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { +bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) { auto & cells = v_cells[strm]; - auto & head = v_heads[strm]; uint32_t v_trans; uint32_t n_layer; @@ -1861,8 +1868,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells, single memcpy + ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row); + } else { + // Slow path: scatter to non-contiguous positions + const void * src = io.read(cell_count * k_size_row); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = sinfo.idxs[0][i] * k_size_row; + ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row); + } + } } } @@ -1893,8 +1909,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells, single memcpy + ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row); + } else { + // Slow path: scatter to non-contiguous positions + const void * src = io.read(cell_count * v_size_row); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = sinfo.idxs[0][i] * v_size_row; + ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row); + } + } } } } else { @@ -1933,10 +1958,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells + const uint32_t h = sinfo.head(); + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (h + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } else { + // Slow path: scatter to non-contiguous positions + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const void * src = io.read(cell_count * v_size_el); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el); + } + } } } } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index bf7821c07c..1868f11857 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -72,6 +72,23 @@ public: void clear() { idxs.clear(); } + + // check if indices are contiguous starting from head() + bool is_contiguous() const { + if (idxs.empty() || idxs[0].empty()) { + return true; + } + if (idxs.size() > 1) { + return false; + } + const uint32_t h = idxs[0][0]; + for (size_t i = 0; i < idxs[0].size(); ++i) { + if (idxs[0][i] != h + i) { + return false; + } + } + return true; + } }; using slot_info_vec_t = std::vector; @@ -264,8 +281,8 @@ private: void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; - bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo); }; class llama_kv_cache_context : public llama_memory_context_i { diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index dfb8439e01..a1b45e4a3c 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), - ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 00bda6a0b1..1143c7a606 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -120,6 +120,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -1802,6 +1803,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); @@ -1817,7 +1819,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -3393,9 +3402,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -5195,6 +5204,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { // mamba2 Mixer SSM params // NOTE: int64_t for tensor dimensions @@ -5205,6 +5215,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5254,12 +5267,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } else { - // mlp layers - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } else { + if (n_expert != 0) { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } } } } break; @@ -6886,7 +6913,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_NEMOTRON_H) { + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -6941,7 +6969,8 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID) { + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); @@ -7122,7 +7151,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (arch == LLM_ARCH_FALCON_H1) { filter_attn = [&](int32_t) { return true; }; filter_recr = [&](int32_t) { return true; }; - } else if (arch == LLM_ARCH_NEMOTRON_H) { + } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { filter_attn = [&](int32_t il) { return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; @@ -7494,6 +7523,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { llm = std::make_unique(*this, params); } break; @@ -7778,6 +7808,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values diff --git a/src/llama-model.h b/src/llama-model.h index f8342cf2cb..c6eb953188 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -113,6 +113,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_31B_A3_5B, LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index b59e821612..11f21ba17a 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1895,7 +1895,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { clean_spaces = false; } else if ( tokenizer_pre == "qwen2" || - tokenizer_pre == "deepseek-r1-qwen") { + tokenizer_pre == "deepseek-r1-qwen" || + tokenizer_pre == "kormo") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; } else if ( diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 5414348888..eb135e63f1 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -107,12 +107,41 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * } ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * ffn_inp = cur; + ggml_tensor * moe_out = + build_moe_ffn(ffn_inp, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + nullptr, // no gate + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_RELU_SQR, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + cb(moe_out, "ffn_moe_out", il); + + ggml_tensor * ffn_shexp = build_ffn(ffn_inp, + model.layers[il].ffn_up_shexp, NULL, NULL, + NULL /* no gate */ , NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } cur = build_cvec(cur, il); cb(cur, "l_out", il); diff --git a/src/models/qwen2.cpp b/src/models/qwen2.cpp index 587a932426..3da4dea3c1 100644 --- a/src/models/qwen2.cpp +++ b/src/models/qwen2.cpp @@ -31,16 +31,25 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9ba559c8df..c3d9f9c324 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -222,6 +222,14 @@ llama_build_and_test(test-backend-ops.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") +# Test for state restore with fragmented KV cache +# Requires a model, uses same args pattern as test-thread-safety +if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf) +else() + llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-be.Q4_0.gguf) +endif() + if (NOT GGML_BACKEND_DL) # these tests use the backends directly and cannot be built with dynamic loading llama_build_and_test(test-barrier.cpp) diff --git a/tests/test-state-restore-fragmented.cpp b/tests/test-state-restore-fragmented.cpp new file mode 100644 index 0000000000..481b39d04c --- /dev/null +++ b/tests/test-state-restore-fragmented.cpp @@ -0,0 +1,122 @@ +// Test for state restore with fragmented KV cache +// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527 +// The issue was that state restore required contiguous KV cache slots, +// which fails when the cache is fragmented. +// +// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false) +// in state_read_meta(), allowing non-contiguous slot allocation. + +#include "arg.h" +#include "common.h" +#include "llama.h" + +#include +#include +#include + +int main(int argc, char ** argv) { + common_params params; + + params.sampling.seed = 1234; + params.kv_unified = true; + params.n_parallel = 3; + params.n_ctx = 256; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + common_init(); + + // init + common_init_result_ptr llama_init = common_init_from_params(params); + + llama_model * model = llama_init->model(); + llama_context * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + fprintf(stderr, "%s : failed to init\n", __func__); + return 1; + } + + GGML_UNUSED(model); + + // tokenize prompt + std::vector tokens(70, 1); + + // interleave the 3 sequences: + // 01201230123... + llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1); + for (size_t i = 0; i < tokens.size(); i++) { + for (int s = 0; s < params.n_parallel; ++s) { + common_batch_add(batch, tokens[i], i, {s}, false); + } + } + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to decode seq 0\n", __func__); + return 1; + } + + fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size()); + + // Save state of seq 1 + std::vector seq_state(llama_state_seq_get_size(ctx, 1)); + const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1); + if (ncopy != seq_state.size()) { + fprintf(stderr, "%s : failed to save seq 1 state\n", __func__); + return 1; + } + fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy); + + // clear seq 1 to create a "hole" in the KV cache (fragmentation) + // 0.20.20.20.2.... + llama_memory_t mem = llama_get_memory(ctx); + llama_memory_seq_rm(mem, 1, -1, -1); + fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__); + + // Now the cache has holes where seq 1 was + // This creates fragmentation - there's no contiguous block large enough + // for the seq 1 state if we only look for contiguous slots + + // Restore seq 1 state into seq 1 (should work with non-contiguous allocation) + // We use seq 1 since it's a valid sequence ID (0 to n_parallel-1) + // Before the fix, this would fail with "failed to find available cells in kv cache" + const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1); + if (nset != seq_state.size()) { + fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n", + __func__, nset, seq_state.size()); + fprintf(stderr, "%s : This is the bug - state restore fails with fragmented KV cache\n", __func__); + llama_batch_free(batch); + return 1; + } + fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset); + + // Verify we can decode with the restored state + // Generate one token to verify the restored state is usable + auto sparams = llama_sampler_chain_default_params(); + llama_sampler * smpl = llama_sampler_chain_init(sparams); + llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed)); + + auto next_token = llama_sampler_sample(smpl, ctx, -1); + auto next_token_str = common_token_to_piece(ctx, next_token); + + common_batch_clear(batch); + common_batch_add(batch, next_token, (int)tokens.size(), {1}, true); + + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to decode with restored state\n", __func__); + llama_sampler_free(smpl); + llama_batch_free(batch); + return 1; + } + + fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str()); + fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__); + + llama_sampler_free(smpl); + llama_batch_free(batch); + + return 0; +} diff --git a/tools/cli/README.md b/tools/cli/README.md new file mode 100644 index 0000000000..1333ed77b7 --- /dev/null +++ b/tools/cli/README.md @@ -0,0 +1 @@ +TODO diff --git a/tools/completion/README.md b/tools/completion/README.md index 54e582de07..57ef394213 100644 --- a/tools/completion/README.md +++ b/tools/completion/README.md @@ -1,4 +1,4 @@ -# llama.cpp/tools/main +# llama.cpp/tools/completion This example program allows you to use various LLaMA language models easily and efficiently. It is specifically designed to work with the [llama.cpp](https://github.com/ggml-org/llama.cpp) project, which provides a plain C/C++ implementation with optional 4-bit quantization support for faster, lower memory inference, and is optimized for desktop CPUs. This program can be used to perform various inference tasks with LLaMA models, including generating text based on user-provided prompts and chat-like interactions with reverse prompts. @@ -27,64 +27,64 @@ Once downloaded, place your model in the models folder in llama.cpp. ##### Input prompt (One-and-done) ```bash -./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time" +./llama-completion -m models/gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time" ``` ##### Conversation mode (Allow for continuous interaction with the model) ```bash -./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma +./llama-completion -m models/gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma ``` ##### Conversation mode using built-in jinja chat template ```bash -./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja +./llama-completion -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja ``` ##### One-and-done query using jinja with custom system prompt and a starting prompt ```bash -./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello" +./llama-completion -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello" ``` ##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it): ```bash -./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 +./llama-completion -m models/gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 ``` ### Windows: ##### Input prompt (One-and-done) ```powershell -./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time" +./llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time" ``` ##### Conversation mode (Allow for continuous interaction with the model) ```powershell -./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma +./llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma ``` ##### Conversation mode using built-in jinja chat template ```powershell -./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja +./llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja ``` ##### One-and-done query using jinja with custom system prompt and a starting prompt ```powershell -./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello" +./llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello" ``` #### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it): ```powershell -llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 +llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 ``` ## Common Options -In this section, we cover the most commonly used options for running the `llama-cli` program with the LLaMA models: +In this section, we cover the most commonly used options for running the `llama-completion` program with the LLaMA models: - `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/gemma-1.1-7b-it.Q4_K_M.gguf`; inferred from `--model-url` if set). - `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g [https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true](https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true)). @@ -97,7 +97,7 @@ In this section, we cover the most commonly used options for running the `llama- ## Input Prompts -The `llama-cli` program provides several ways to interact with the LLaMA models using input prompts: +The `llama-completion` program provides several ways to interact with the LLaMA models using input prompts: - `--prompt PROMPT`: Provide a prompt directly as a command-line option. - `--file FNAME`: Provide a file containing a prompt or multiple prompts. @@ -107,7 +107,7 @@ The `llama-cli` program provides several ways to interact with the LLaMA models ## Interaction -The `llama-cli` program offers a seamless way to interact with LLaMA models, allowing users to engage in real-time conversations or provide instructions for specific tasks. The interactive mode can be triggered using various options, including `--interactive` and `--interactive-first`. +The `llama-completion` program offers a seamless way to interact with LLaMA models, allowing users to engage in real-time conversations or provide instructions for specific tasks. The interactive mode can be triggered using various options, including `--interactive` and `--interactive-first`. In interactive mode, users can participate in text generation by injecting their input during the process. Users can press `Ctrl+C` at any time to interject and type their input, followed by pressing `Return` to submit it to the LLaMA model. To submit additional lines without finalizing input, users can end the current line with a backslash (`\`) and continue typing. @@ -136,7 +136,7 @@ To overcome this limitation, you can use the `--in-prefix` flag to add a space o The `--in-prefix` flag is used to add a prefix to your input, primarily, this is used to insert a space after the reverse prompt. Here's an example of how to use the `--in-prefix` flag in conjunction with the `--reverse-prompt` flag: ```sh -./llama-cli -r "User:" --in-prefix " " +./llama-completion -r "User:" --in-prefix " " ``` ### In-Suffix @@ -144,7 +144,7 @@ The `--in-prefix` flag is used to add a prefix to your input, primarily, this is The `--in-suffix` flag is used to add a suffix after your input. This is useful for adding an "Assistant:" prompt after the user's input. It's added after the new-line character (`\n`) that's automatically added to the end of the user's input. Here's an example of how to use the `--in-suffix` flag in conjunction with the `--reverse-prompt` flag: ```sh -./llama-cli -r "User:" --in-prefix " " --in-suffix "Assistant:" +./llama-completion -r "User:" --in-prefix " " --in-suffix "Assistant:" ``` When --in-prefix or --in-suffix options are enabled the chat template ( --chat-template ) is disabled diff --git a/tools/llama-bench/README.md b/tools/llama-bench/README.md index 87d9c0a219..c837bb6d26 100644 --- a/tools/llama-bench/README.md +++ b/tools/llama-bench/README.md @@ -80,7 +80,7 @@ Each test is repeated the number of times given by `-r`, and the results are ave Using the `-d ` option, each test can be run at a specified context depth, prefilling the KV cache with `` tokens. -For a description of the other options, see the [main example](../main/README.md). +For a description of the other options, see the [completion example](../completion/README.md). > [!NOTE] > The measurements with `llama-bench` do not include the times for tokenization and for sampling. diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 8ca5cf1e39..a927fa53e0 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -70,6 +70,13 @@ struct clip_hparams { int32_t n_mel_bins = 0; // whisper preprocessor int32_t proj_stack_factor = 0; // ultravox + // audio-to-mel preprocessor params + int32_t audio_chunk_len = -1; // in seconds + int32_t audio_sample_rate = -1; + int32_t audio_n_fft = -1; + int32_t audio_window_len = -1; + int32_t audio_hop_len = -1; + // legacy bool has_llava_projector = false; int minicpmv_version = 0; @@ -323,3 +330,5 @@ struct clip_model { || proj_type == PROJECTOR_TYPE_VOXTRAL; } }; + +const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 8d1d14299a..5e4daa261b 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1174,11 +1174,15 @@ struct clip_model_loader { model.proj_type == PROJECTOR_TYPE_VOXTRAL || model.proj_type == PROJECTOR_TYPE_GLMA; get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); - if (hparams.n_mel_bins != 128) { - throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__)); - } hparams.ffn_op = FFN_GELU_ERF; log_ffn_op = "gelu_erf"; // temporary solution for logging + + // audio preprocessing params + hparams.audio_chunk_len = 30; // in seconds + hparams.audio_sample_rate = 16000; + hparams.audio_n_fft = 400; + hparams.audio_window_len = 400; + hparams.audio_hop_len = 160; } break; case PROJECTOR_TYPE_DEEPSEEKOCR: { @@ -1227,6 +1231,11 @@ struct clip_model_loader { LOG_INF("\n--- audio hparams ---\n"); LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins); LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor); + LOG_INF("%s: audio_chunk_len: %d\n", __func__, hparams.audio_chunk_len); + LOG_INF("%s: audio_sample_rate: %d\n", __func__, hparams.audio_sample_rate); + LOG_INF("%s: audio_n_fft: %d\n", __func__, hparams.audio_n_fft); + LOG_INF("%s: audio_window_len: %d\n", __func__, hparams.audio_window_len); + LOG_INF("%s: audio_hop_len: %d\n", __func__, hparams.audio_hop_len); } LOG_INF("\n"); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); @@ -3983,3 +3992,7 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel batch->entries.push_back(clip_image_f32_ptr(audio)); batch->is_audio = true; } + +const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) { + return &ctx->model.hparams; +} diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index 4d053895cd..f68829a61a 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -11,63 +11,149 @@ // most of the code here is copied from whisper.cpp -// align x to upper multiple of n -#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) +constexpr bool DEBUG = false; -namespace whisper_preprocessor { +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; -#define SIN_COS_N_COUNT WHISPER_N_FFT -namespace { -struct whisper_global_cache { - // In FFT, we frequently use sine and cosine operations with the same values. - // We can use precalculated values to speed up the process. - float sin_vals[SIN_COS_N_COUNT]; - float cos_vals[SIN_COS_N_COUNT]; + std::vector data; +}; - // Hann window (Use cosf to eliminate difference) - // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html - // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 - float hann_window[WHISPER_N_FFT]; +// note: this global cache is shared among all preprocessors +// if we want to use multiple preprocessors at the same time, +// we will need to enclose it in the preprocessor class in the future +static struct mtmd_audio_global_cache { + // precomputed sin/cos table for FFT + std::vector sin_vals; + std::vector cos_vals; - whisper_global_cache() { - fill_sin_cos_table(); - fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window); - } + // hann window + std::vector hann_window; - void fill_sin_cos_table() { - for (int i = 0; i < SIN_COS_N_COUNT; i++) { - double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; + // mel filter bank + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; sin_vals[i] = sinf(theta); cos_vals[i] = cosf(theta); } } - void fill_hann_window(int length, bool periodic, float * output) { + void fill_hann_window(int length, bool periodic) { + hann_window.resize(length); int offset = -1; if (periodic) { offset = 0; } for (int i = 0; i < length; i++) { - output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); } } -} global_cache; -} + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix( + int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code + ) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; + } + + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); + } + + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } + + const int n_fft_bins = n_fft / 2 + 1; + + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; + + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; + + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; + } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + } + } + + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); + + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); + } + } + } + } +} g_cache; // naive Discrete Fourier Transform // input is real-valued // output is complex-valued -static void dft(const float* in, int N, float* out) { - const int sin_cos_step = SIN_COS_N_COUNT / N; +static void dft(const float * in, int N, float * out) { + const int n_sin_cos_vals = g_cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N - re += in[n]*global_cache.cos_vals[idx]; // cos(t) - im -= in[n]*global_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N + re += in[n] * g_cache.cos_vals[idx]; // cos(t) + im -= in[n] * g_cache.sin_vals[idx]; // sin(t) } out[k*2 + 0] = re; @@ -79,7 +165,8 @@ static void dft(const float* in, int N, float* out) { // poor man's implementation - use something better // input is real-valued // output is complex-valued -static void fft(float* in, int N, float* out) { +static void fft(float * in, int N, float * out) { + const int n_sin_cos_vals = g_cache.sin_vals.size(); if (N == 1) { out[0] = in[0]; out[1] = 0; @@ -106,11 +193,11 @@ static void fft(float* in, int N, float* out) { float* odd_fft = even_fft + N; fft(odd, half_N, odd_fft); - const int sin_cos_step = SIN_COS_N_COUNT / N; + const int sin_cos_step = n_sin_cos_vals / N; for (int k = 0; k < half_N; k++) { int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = global_cache.cos_vals[idx]; // cos(t) - float im = -global_cache.sin_vals[idx]; // sin(t) + float re = g_cache.cos_vals[idx]; // cos(t) + float im = -g_cache.sin_vals[idx]; // sin(t) float re_odd = odd_fft[2*k + 0]; float im_odd = odd_fft[2*k + 1]; @@ -123,20 +210,34 @@ static void fft(float* in, int N, float* out) { } } +struct filter_params { + int32_t n_mel; + int32_t n_fft_bins; + int32_t hann_window_size; + int32_t hop_length; + int32_t sample_rate; + bool center_padding = false; + float preemph = 0.f; + bool use_natural_log = false; + bool norm_per_feature = false; +}; + static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, int n_samples, int frame_size, int frame_step, int n_threads, - const whisper_filters & filters, whisper_mel & mel) { + const filter_params & params, mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); - int n_fft = filters.n_fft; + int n_fft_bins = params.n_fft_bins; int i = ith; - // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist - WHISPER_ASSERT(n_fft == 1 + (frame_size / 2)); + const auto & filters = g_cache.filters; + // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist + GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); + GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero - for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; // apply Hann window (~10% faster) @@ -154,36 +255,39 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. - for (int j = 0; j < n_fft; j++) { + for (int j = 0; j < n_fft_bins; j++) { fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); } // mel spectrogram - for (int j = 0; j < mel.n_mel; j++) { + for (int j = 0; j < out.n_mel; j++) { double sum = 0.0; // unroll loop (suggested by GH user @lunixbochs) int k = 0; - for (k = 0; k < n_fft - 3; k += 4) { + for (k = 0; k < n_fft_bins - 3; k += 4) { + size_t idx = size_t(j) * size_t(n_fft_bins) + size_t(k); sum += - fft_out[k + 0] * filters.data[j * n_fft + k + 0] + - fft_out[k + 1] * filters.data[j * n_fft + k + 1] + - fft_out[k + 2] * filters.data[j * n_fft + k + 2] + - fft_out[k + 3] * filters.data[j * n_fft + k + 3]; + fft_out[k + 0] * filters.data[idx + 0] + + fft_out[k + 1] * filters.data[idx + 1] + + fft_out[k + 2] * filters.data[idx + 2] + + fft_out[k + 3] * filters.data[idx + 3]; } // handle n_fft remainder - for (; k < n_fft; k++) { - sum += fft_out[k] * filters.data[j * n_fft + k]; + for (; k < n_fft_bins; k++) { + sum += fft_out[k] * filters.data[j * n_fft_bins + k]; } - sum = log10(std::max(sum, 1e-10)); - mel.data[j * mel.n_len + i] = sum; + sum = params.use_natural_log + ? log(sum + 5.960464477539063e-08) + : log10(std::max(sum, 1e-10)); + out.data[j * out.n_len + i] = sum; } } // Otherwise fft_out are all zero - double sum = log10(1e-10); - for (; i < mel.n_len; i += n_threads) { - for (int j = 0; j < mel.n_mel; j++) { - mel.data[j * mel.n_len + i] = sum; + double sum = params.use_natural_log ? log(1e-10) : log10(1e-10); + for (; i < out.n_len; i += n_threads) { + for (int j = 0; j < out.n_mel; j++) { + out.data[j * out.n_len + i] = sum; } } } @@ -191,115 +295,212 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 static bool log_mel_spectrogram( const float * samples, - const int n_samples, - const int /*sample_rate*/, - const int frame_size, - const int frame_step, - const int n_mel, - const int n_threads, - const whisper_filters & filters, - const bool debug, - whisper_mel & mel) { + const int n_samples_in, + const int n_threads, + const filter_params & params, + mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); + out.n_len_org = n_samples_in; + int n_samples = n_samples_in; + // Hann window - WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size"); - const float * hann = global_cache.hann_window; + const float * hann = g_cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; - // Calculate the length of padding - int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; - int64_t stage_2_pad = frame_size / 2; - - // Initialize a vector and copy data from C array to it. + // Padding std::vector samples_padded; - samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); - std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + if (params.center_padding) { + const auto pad_amount = frame_size / 2; + samples_padded = std::vector(n_samples + 2 * pad_amount, 0); + std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount); + samples = samples_padded.data(); + n_samples = samples_padded.size(); + } else { + // existing padding logic + int64_t stage_1_pad = params.sample_rate * 30; + int64_t stage_2_pad = frame_size / 2; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + // reflective pad 200 samples at the beginning of audio + if (n_samples < stage_2_pad + 1) { + // TODO: Handle short audio differently or return error + return false; + } + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + } - // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio - std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + // preemphasis + if (params.preemph) { + const int pad_amount = frame_size / 2; + const float preemph = 0.97f; + float prev = samples_padded[pad_amount]; + for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { + float cur = samples_padded[i]; + samples_padded[i] = cur - preemph * prev; + prev = cur; + } + } - // reflective pad 200 samples at the beginning of audio - std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + // pad hann window if it's smaller than frame_size + // TODO: probably unnecessary here? (or better doing it in g_cache?) + std::vector hann_window_padded; + if (params.hann_window_size < frame_size) { + hann_window_padded.resize(frame_size); + const int padding = (frame_size - params.hann_window_size) / 2; + std::copy(hann, hann + params.hann_window_size, &hann_window_padded[padding]); + hann = hann_window_padded.data(); + } - mel.n_mel = n_mel; - // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 - // Calculate number of frames + remove the last frame - mel.n_len = (samples_padded.size() - frame_size) / frame_step; - // Calculate semi-padded sample length to ensure compatibility - mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; - mel.data.resize(mel.n_mel * mel.n_len); + + out.n_mel = params.n_mel; + out.n_len = (n_samples - frame_size) / frame_step + 1; + // TODO: handle these checks better + if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) { + LOG_ERR("%s: size overflow\n", __func__); + return false; + } + if (n_samples < frame_size) { + LOG_ERR("%s: not enough samples after padding\n", __func__); + return false; + } + out.data.resize(out.n_mel * out.n_len); { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw] = std::thread( log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples + stage_2_pad, frame_size, frame_step, n_threads, - std::cref(filters), std::ref(mel)); + n_samples, frame_size, frame_step, n_threads, + std::cref(params), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); - + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } } - // clamping and normalization - double mmax = -1e20; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] > mmax) { - mmax = mel.data[i]; + const int effective_n_len = n_samples_in / frame_step; + if (params.norm_per_feature) { + for (int i = 0; i < out.n_mel; i++) { + double mean = 0; + for (int j = 0; j < effective_n_len; ++j) { + mean += out.data[i * out.n_len + j]; + } + mean /= effective_n_len; + + double var = 0.0; + for (int j = 0; j < effective_n_len; ++j) { + const double value = out.data[i * out.n_len + j] - mean; + var += value * value; + } + var /= effective_n_len - 1; // unbiased + const double mstd = std::sqrt(var + 1e-5); + + for (int j = 0; j < effective_n_len; ++j) { + auto &value = out.data[i * out.n_len + j]; + value = (value - mean) / mstd; + } + + // pad the rest with zeros + for (int j = effective_n_len; j < out.n_len; ++j) { + out.data[i * out.n_len + j] = 0.0; + } } - } - - mmax -= 8.0; - - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] < mmax) { - mel.data[i] = mmax; + } else { + // clamping and normalization + double mmax = -1e20; + for (int i = 0; i < out.n_mel*out.n_len; i++) { + if (out.data[i] > mmax) { + mmax = out.data[i]; + } } - mel.data[i] = (mel.data[i] + 4.0)/4.0; + mmax -= 8.0; + + for (int i = 0; i < out.n_mel*out.n_len; i++) { + if (out.data[i] < mmax) { + out.data[i] = mmax; + } + out.data[i] = (out.data[i] + 4.0)/4.0; + } } // Dump log_mel_spectrogram - if (debug) { + if (DEBUG) { std::ofstream outFile("log_mel_spectrogram.json"); outFile << "["; - for (uint64_t i = 0; i < mel.data.size() - 1; i++) { - outFile << mel.data[i] << ", "; + for (uint64_t i = 0; i < out.data.size() - 1; i++) { + outFile << out.data[i] << ", "; } - outFile << mel.data[mel.data.size() - 1] << "]"; + outFile << out.data[out.data.size() - 1] << "]"; outFile.close(); } return true; } -bool preprocess_audio( +// +// mtmd_audio_preprocessor_whisper +// + +void mtmd_audio_preprocessor_whisper::initialize() { + g_cache.fill_sin_cos_table(hparams.audio_n_fft); + g_cache.fill_hann_window(hparams.audio_window_len, true); + g_cache.fill_mel_filterbank_matrix( + hparams.n_mel_bins, + hparams.audio_n_fft, + hparams.audio_sample_rate); +} + +bool mtmd_audio_preprocessor_whisper::preprocess( const float * samples, size_t n_samples, - const whisper_filters & filters, - std::vector & output) { - + std::vector & output) { if (n_samples == 0) { // empty audio return false; } - whisper_mel out_full; + std::vector smpl; + // if input is too short, pad with zeros + // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram + // TODO: maybe handle this better + size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + if (n_samples < min_samples) { + smpl.resize(min_samples, 0.0f); + std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); + samples = smpl.data(); + n_samples = smpl.size(); + } + + filter_params params; + params.n_mel = hparams.n_mel_bins; + params.n_fft_bins = 1 + (hparams.audio_n_fft / 2); + params.hann_window_size = hparams.audio_window_len; + params.hop_length = hparams.audio_hop_len; + params.sample_rate = hparams.audio_sample_rate; + params.center_padding = false; + params.preemph = 0.0f; // disabled + params.use_natural_log = false; + params.norm_per_feature = false; + + // make sure the global cache is initialized + GGML_ASSERT(!g_cache.sin_vals.empty()); + GGML_ASSERT(!g_cache.cos_vals.empty()); + GGML_ASSERT(!g_cache.filters.data.empty()); + + mtmd_audio_mel out_full; bool ok = log_mel_spectrogram( samples, n_samples, - COMMON_SAMPLE_RATE, - WHISPER_N_FFT, - WHISPER_HOP_LENGTH, - filters.n_mel, 4, // n_threads - filters, - false, // debug + params, out_full); if (!ok) { return false; @@ -307,7 +508,9 @@ bool preprocess_audio( // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel // we always expect the mel to have 3000 silent frames at the end - // printf("n_len %d\n", out_full.n_len); + if (DEBUG) { + printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); + } const size_t frames_per_chunk = 3000; GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { @@ -316,7 +519,7 @@ bool preprocess_audio( break; // last uncomplete chunk will always be a padded chunk, safe to ignore } - whisper_mel out_chunk; + mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; out_chunk.n_len_org = out_full.n_mel; // unused @@ -332,438 +535,3 @@ bool preprocess_audio( return true; } - -} // namespace whisper_preprocessor - - -// precalculated mel filter banks -// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function -// -// generated from python code: -// -// from numpy import load -// data = load('mel_filters.npz') -// lst = data.files -// for item in lst: -// print(item) -// print(data[item].shape) -// n_mel = data[item].shape[0] -// n_fft = data[item].shape[1] -// for i, row in enumerate(data[item]): -// for j, val in enumerate(row): -// val = val * 1000.0 -// if val != 0: -// print(f"data[{i*n_fft + j}] = {val:.6f};") - -namespace whisper_precalc_filters { - -whisper_preprocessor::whisper_filters get_128_bins() { - whisper_preprocessor::whisper_filters filters; - filters.n_mel = 128; - filters.n_fft = 201; - std::vector data(filters.n_mel * filters.n_fft, 0.0f); - - data[1] = 12.37398665; - data[202] = 30.39256483; - data[404] = 24.74797331; - data[605] = 18.01857911; - data[807] = 37.12195903; - data[1008] = 5.64459199; - data[1009] = 6.72939420; - data[1210] = 36.03715822; - data[1412] = 19.10337992; - data[1613] = 23.66316877; - data[1815] = 31.47736564; - data[2016] = 11.28918398; - data[2017] = 1.08480197; - data[2218] = 41.68175161; - data[2420] = 13.45878839; - data[2621] = 29.30776216; - data[2823] = 25.83277412; - data[3024] = 16.93377644; - data[3226] = 38.20675984; - data[3427] = 4.55979025; - data[3428] = 7.81419594; - data[3629] = 34.95235741; - data[3831] = 20.18818259; - data[4032] = 22.57836796; - data[4234] = 32.56217018; - data[4435] = 10.20438317; - data[4436] = 2.16960395; - data[4637] = 40.59694707; - data[4839] = 14.54358920; - data[5040] = 28.22295949; - data[5242] = 26.91757679; - data[5443] = 15.84897563; - data[5645] = 39.29156065; - data[5846] = 3.47498828; - data[5847] = 8.89899861; - data[6048] = 33.86755288; - data[6250] = 21.27298526; - data[6451] = 21.49356715; - data[6653] = 33.64697099; - data[6854] = 9.11958050; - data[6855] = 3.25440569; - data[7056] = 39.51214626; - data[7258] = 15.62839188; - data[7459] = 27.13815868; - data[7661] = 28.00237760; - data[7862] = 14.76417296; - data[8064] = 40.37636518; - data[8265] = 2.38068704; - data[8266] = 10.20263787; - data[8467] = 31.61146119; - data[8669] = 24.54700135; - data[8870] = 15.32919332; - data[8871] = 1.66583748; - data[9072] = 36.72905266; - data[9274] = 20.09709924; - data[9475] = 16.93102531; - data[9476] = 2.90265540; - data[9677] = 32.84499049; - data[9879] = 23.52004871; - data[10080] = 11.03894413; - data[10081] = 10.72582975; - data[10282] = 22.71829173; - data[10484] = 32.27872774; - data[10685] = 0.11626833; - data[10686] = 22.85348251; - data[10887] = 8.56344029; - data[10888] = 14.97978810; - data[11089] = 15.51398356; - data[11090] = 8.51490628; - data[11291] = 21.10680379; - data[11292] = 3.32652032; - data[11493] = 25.47064796; - data[11695] = 27.35907957; - data[11896] = 0.65853616; - data[11897] = 23.83812517; - data[12098] = 3.44359246; - data[12099] = 21.22455277; - data[12300] = 5.35842171; - data[12301] = 19.42555793; - data[12502] = 6.49324711; - data[12503] = 18.35542172; - data[12704] = 6.93138083; - data[12705] = 17.93504693; - data[12906] = 6.74968259; - data[12907] = 18.09151843; - data[13108] = 6.01899112; - data[13109] = 18.75767298; - data[13310] = 4.80452832; - data[13311] = 19.87172849; - data[13512] = 3.16627859; - data[13513] = 21.37690969; - data[13514] = 1.25317345; - data[13714] = 1.15934468; - data[13715] = 20.80361731; - data[13716] = 4.04486805; - data[13917] = 17.55363122; - data[13918] = 7.08320038; - data[14119] = 14.07538634; - data[14120] = 10.32655034; - data[14321] = 10.40921453; - data[14322] = 13.73696327; - data[14523] = 6.59187697; - data[14524] = 17.27988198; - data[14525] = 1.46804214; - data[14725] = 2.65681883; - data[14726] = 18.09193194; - data[14727] = 5.85655728; - data[14928] = 13.34277913; - data[14929] = 10.28267574; - data[15130] = 8.56800377; - data[15131] = 14.72230814; - data[15132] = 1.04039861; - data[15332] = 3.79085587; - data[15333] = 17.14678481; - data[15334] = 6.11609267; - data[15535] = 11.75929047; - data[15536] = 11.13393717; - data[15737] = 6.43857848; - data[15738] = 16.07806236; - data[15739] = 4.23917221; - data[15939] = 1.19989377; - data[15940] = 12.75671553; - data[15941] = 9.65298992; - data[16142] = 7.06935255; - data[16143] = 14.94054683; - data[16144] = 4.19024844; - data[16344] = 1.51483389; - data[16345] = 12.00899947; - data[16346] = 9.84823331; - data[16547] = 6.10224018; - data[16548] = 15.33857174; - data[16549] = 5.57676842; - data[16749] = 0.36827257; - data[16750] = 9.89749376; - data[16751] = 11.35340426; - data[16752] = 2.05122307; - data[16952] = 3.89297144; - data[16953] = 12.97352277; - data[16954] = 8.06631614; - data[17155] = 6.74493238; - data[17156] = 13.85874674; - data[17157] = 5.41190524; - data[17357] = 0.74220158; - data[17358] = 8.98779090; - data[17359] = 11.37871388; - data[17360] = 3.32958088; - data[17560] = 2.82313535; - data[17561] = 10.68049297; - data[17562] = 9.43340641; - data[17563] = 1.76325557; - data[17763] = 4.39018616; - data[17764] = 11.87758986; - data[17765] = 7.97005836; - data[17766] = 0.66104700; - data[17966] = 5.49466675; - data[17967] = 12.62953598; - data[17968] = 6.93987962; - data[18169] = 6.18401915; - data[18170] = 12.93473132; - data[18171] = 6.29778765; - data[18371] = 0.02325210; - data[18372] = 6.50206627; - data[18373] = 12.32661773; - data[18374] = 6.00216538; - data[18574] = 0.31548753; - data[18575] = 6.48925547; - data[18576] = 12.04130240; - data[18577] = 6.01462880; - data[18777] = 0.29979556; - data[18778] = 6.18288014; - data[18779] = 12.04272825; - data[18780] = 6.29981188; - data[18781] = 0.55689598; - data[18980] = 0.01120471; - data[18981] = 5.61729167; - data[18982] = 11.22337859; - data[18983] = 6.82516303; - data[18984] = 1.35264499; - data[19184] = 4.82410006; - data[19185] = 10.16623247; - data[19186] = 7.56075513; - data[19187] = 2.34590308; - data[19387] = 3.83235747; - data[19388] = 8.92296247; - data[19389] = 8.47910438; - data[19390] = 3.50978645; - data[19590] = 2.66873185; - data[19591] = 7.51965167; - data[19592] = 9.55500547; - data[19593] = 4.81966138; - data[19594] = 0.08431751; - data[19793] = 1.35767367; - data[19794] = 5.98019501; - data[19795] = 10.60271543; - data[19796] = 6.25298498; - data[19797] = 1.74059917; - data[19997] = 4.32644226; - data[19998] = 8.73131864; - data[19999] = 7.78916525; - data[20000] = 3.48923868; - data[20200] = 2.57835095; - data[20201] = 6.77582854; - data[20202] = 9.40941647; - data[20203] = 5.31194592; - data[20204] = 1.21447595; - data[20403] = 0.75411191; - data[20404] = 4.75395704; - data[20405] = 8.75380263; - data[20406] = 7.19209015; - data[20407] = 3.28754401; - data[20607] = 2.68179690; - data[20608] = 6.49331464; - data[20609] = 9.11457930; - data[20610] = 5.39387390; - data[20611] = 1.67316827; - data[20810] = 0.57394296; - data[20811] = 4.20600036; - data[20812] = 7.83805829; - data[20813] = 7.52023002; - data[20814] = 3.97470826; - data[20815] = 0.42918732; - data[21014] = 1.90464477; - data[21015] = 5.36569161; - data[21016] = 8.82673822; - data[21017] = 6.27609482; - data[21018] = 2.89750961; - data[21218] = 2.89885257; - data[21219] = 6.19694078; - data[21220] = 8.56699049; - data[21221] = 5.34748193; - data[21222] = 2.12797290; - data[21421] = 0.44750227; - data[21422] = 3.59030394; - data[21423] = 6.73310598; - data[21424] = 7.77023612; - data[21425] = 4.70231380; - data[21426] = 1.63439126; - data[21625] = 1.01536023; - data[21626] = 4.01018746; - data[21627] = 7.00501446; - data[21628] = 7.23442994; - data[21629] = 4.31095669; - data[21630] = 1.38748321; - data[21829] = 1.33348850; - data[21830] = 4.18730825; - data[21831] = 7.04112789; - data[21832] = 6.93188375; - data[21833] = 4.14605811; - data[21834] = 1.36023236; - data[22033] = 1.42879714; - data[22034] = 4.14824858; - data[22035] = 6.86769979; - data[22036] = 6.83705276; - data[22037] = 4.18239459; - data[22038] = 1.52773573; - data[22237] = 1.32610439; - data[22238] = 3.91751388; - data[22239] = 6.50892360; - data[22240] = 6.92639686; - data[22241] = 4.39672917; - data[22242] = 1.86706171; - data[22441] = 1.04827771; - data[22442] = 3.51767405; - data[22443] = 5.98707050; - data[22444] = 7.17824046; - data[22445] = 4.76767914; - data[22446] = 2.35711760; - data[22645] = 0.61636406; - data[22646] = 2.96949223; - data[22647] = 5.32262027; - data[22648] = 7.57265091; - data[22649] = 5.27558755; - data[22650] = 2.97852419; - data[22651] = 0.68146095; - data[22849] = 0.04971400; - data[22850] = 2.29204819; - data[22851] = 4.53438237; - data[22852] = 6.77671656; - data[22853] = 5.90240723; - data[22854] = 3.71349836; - data[22855] = 1.52458926; - data[23054] = 1.50285335; - data[23055] = 3.63961048; - data[23056] = 5.77636715; - data[23057] = 6.63159089; - data[23058] = 4.54574358; - data[23059] = 2.45989650; - data[23060] = 0.37404924; - data[23258] = 0.61795861; - data[23259] = 2.65410915; - data[23260] = 4.69025923; - data[23261] = 6.72641024; - data[23262] = 5.46034705; - data[23263] = 3.47270933; - data[23264] = 1.48507138; - data[23463] = 1.59233576; - data[23464] = 3.53261665; - data[23465] = 5.47289755; - data[23466] = 6.44368259; - data[23467] = 4.54962999; - data[23468] = 2.65557761; - data[23469] = 0.76152512; - data[23667] = 0.46749352; - data[23668] = 2.31641904; - data[23669] = 4.16534441; - data[23670] = 6.01426978; - data[23671] = 5.67844696; - data[23672] = 3.87357362; - data[23673] = 2.06870004; - data[23674] = 0.26382666; - data[23872] = 1.05349103; - data[23873] = 2.81536230; - data[23874] = 4.57723346; - data[23875] = 6.33910485; - data[23876] = 5.12815686; - data[23877] = 3.40826320; - data[23878] = 1.68837002; - data[24077] = 1.43350090; - data[24078] = 3.11241671; - data[24079] = 4.79133241; - data[24080] = 6.40943693; - data[24081] = 4.77052201; - data[24082] = 3.13160778; - data[24083] = 1.49269309; - data[24281] = 0.02932359; - data[24282] = 1.62918994; - data[24283] = 3.22905602; - data[24284] = 4.82892245; - data[24285] = 6.14671456; - data[24286] = 4.58496623; - data[24287] = 3.02321767; - data[24288] = 1.46146910; - data[24486] = 0.13601698; - data[24487] = 1.66055572; - data[24488] = 3.18509457; - data[24489] = 4.70963307; - data[24490] = 6.04072399; - data[24491] = 4.55250870; - data[24492] = 3.06429295; - data[24493] = 1.57607743; - data[24494] = 0.08786193; - data[24691] = 0.09328097; - data[24692] = 1.54603878; - data[24693] = 2.99879676; - data[24694] = 4.45155473; - data[24695] = 5.90431225; - data[24696] = 4.65566106; - data[24697] = 3.23751615; - data[24698] = 1.81937125; - data[24699] = 0.40122634; - data[24897] = 1.30262633; - data[24898] = 2.68698297; - data[24899] = 4.07133950; - data[24900] = 5.45569602; - data[24901] = 4.87832492; - data[24902] = 3.52695142; - data[24903] = 2.17557792; - data[24904] = 0.82420459; - data[25102] = 0.94595028; - data[25103] = 2.26512621; - data[25104] = 3.58430226; - data[25105] = 4.90347855; - data[25106] = 5.20569785; - data[25107] = 3.91795207; - data[25108] = 2.63020652; - data[25109] = 1.34246063; - data[25110] = 0.05471494; - data[25307] = 0.49037894; - data[25308] = 1.74744334; - data[25309] = 3.00450763; - data[25310] = 4.26157191; - data[25311] = 5.51863620; - data[25312] = 4.39707236; - data[25313] = 3.16995848; - data[25314] = 1.94284460; - data[25315] = 0.71573065; - data[25513] = 1.14698056; - data[25514] = 2.34485767; - data[25515] = 3.54273478; - data[25516] = 4.74061165; - data[25517] = 4.95198462; - data[25518] = 3.78264743; - data[25519] = 2.61331047; - data[25520] = 1.44397374; - data[25521] = 0.27463681; - data[25718] = 0.47569509; - data[25719] = 1.61717169; - data[25720] = 2.75864848; - data[25721] = 3.90012516; - data[25722] = 5.04160160; - data[25723] = 4.45712078; - data[25724] = 3.34284059; - data[25725] = 2.22856039; - data[25726] = 1.11428020; - - for (auto & val : data) { - val /= 1000.0f; - } - - filters.data = std::move(data); - return filters; -} - -} // namespace whisper_precalc_filters diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index 0e552347a0..1b454337cb 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -1,6 +1,7 @@ #pragma once #include "ggml.h" +#include "clip-model.h" #include #include @@ -8,18 +9,7 @@ #define MTMD_INTERNAL_HEADER -#define WHISPER_ASSERT GGML_ASSERT - -#define WHISPER_SAMPLE_RATE 16000 -#define WHISPER_N_FFT 400 -#define WHISPER_HOP_LENGTH 160 -#define WHISPER_CHUNK_SIZE 30 - -#define COMMON_SAMPLE_RATE 16000 - -namespace whisper_preprocessor { - -struct whisper_mel { +struct mtmd_audio_mel { int n_len; int n_len_org; int n_mel; @@ -27,23 +17,18 @@ struct whisper_mel { std::vector data; }; -struct whisper_filters { - int32_t n_mel; - int32_t n_fft; +struct mtmd_audio_preprocessor { + const clip_hparams & hparams; - std::vector data; + mtmd_audio_preprocessor(const clip_ctx * ctx): hparams(*clip_get_hparams(ctx)) {} + + virtual ~mtmd_audio_preprocessor() = default; + virtual void initialize() = 0; // NOT thread-safe + virtual bool preprocess(const float * samples, size_t n_samples, std::vector & output) = 0; }; -bool preprocess_audio( - const float * samples, - size_t n_samples, - const whisper_filters & filters, - std::vector & output); - -} // namespace whisper_preprocessor - -namespace whisper_precalc_filters { - -whisper_preprocessor::whisper_filters get_128_bins(); - -} // namespace whisper_precalc_filters +struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { + mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} + void initialize() override; + bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; +}; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 33042722eb..d82695da01 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -151,8 +151,7 @@ struct mtmd_context { // string template for slice image delimiters with row/col (idefics3) std::string sli_img_start_tmpl; - // for whisper, we pre-calculate the mel filter bank - whisper_preprocessor::whisper_filters w_filters; + std::unique_ptr audio_preproc; // TODO @ngxson : add timings @@ -317,14 +316,25 @@ struct mtmd_context { GGML_ASSERT(ctx_a != nullptr); projector_type proj = clip_get_projector_type(ctx_a); - if (clip_has_whisper_encoder(ctx_a)) { - // TODO @ngxson : check if model n_mel is 128 or 80 - w_filters = whisper_precalc_filters::get_128_bins(); - } - LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n" " https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__); + // set preprocessor + switch (proj) { + case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_QWEN25O: + case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_VOXTRAL: + audio_preproc = std::make_unique(ctx_a); + break; + default: + GGML_ABORT("unsupported audio projector type"); + } + + // initialize audio preprocessor + audio_preproc->initialize(); + + // set special tokens if (proj == PROJECTOR_TYPE_QWEN2A) { // <|audio_bos|> ... (embeddings) ... <|audio_eos|> aud_beg = "<|audio_bos|>"; @@ -653,11 +663,10 @@ struct mtmd_tokenizer { } // preprocess audio - GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded - std::vector mel_spec_chunks; + std::vector mel_spec_chunks; const float * samples = (const float *)bitmap->data.data(); size_t n_samples = bitmap->data.size() / sizeof(float); - bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks); + bool ok = ctx->audio_preproc->preprocess(samples, n_samples, mel_spec_chunks); if (!ok) { LOG_ERR("Unable to preprocess audio\n"); return 2; @@ -864,8 +873,7 @@ int mtmd_get_audio_bitrate(mtmd_context * ctx) { if (!ctx->ctx_a) { return -1; } - // for now, we assume that all audio models have the same bitrate - return 16000; // 16kHz + return clip_get_hparams(ctx->ctx_a)->audio_sample_rate; } // diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 3fd631b77a..2ff90e800a 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/README.md b/tools/server/webui/README.md index d995271fc4..98b01fdcd7 100644 --- a/tools/server/webui/README.md +++ b/tools/server/webui/README.md @@ -619,11 +619,12 @@ flowchart TB ### Test Types -| Type | Tool | Location | Command | -| ------------- | ------------------ | -------------------------------- | ------------------- | -| **E2E** | Playwright | `tests/e2e/` | `npm run test:e2e` | -| **Unit** | Vitest | `tests/client/`, `tests/server/` | `npm run test:unit` | -| **UI/Visual** | Storybook + Vitest | `tests/stories/` | `npm run test:ui` | +| Type | Tool | Location | Command | +| ------------- | ------------------ | ---------------- | ------------------- | +| **Unit** | Vitest | `tests/unit/` | `npm run test:unit` | +| **UI/Visual** | Storybook + Vitest | `tests/stories/` | `npm run test:ui` | +| **E2E** | Playwright | `tests/e2e/` | `npm run test:e2e` | +| **Client** | Vitest | `tests/client/`. | `npm run test:unit` | ### Running Tests diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json index c20ab3cfde..1c970ae7a8 100644 --- a/tools/server/webui/package.json +++ b/tools/server/webui/package.json @@ -13,12 +13,11 @@ "reset": "rm -rf .svelte-kit node_modules", "format": "prettier --write .", "lint": "prettier --check . && eslint .", - "test": "npm run test:ui -- --run && npm run test:client -- --run && npm run test:server -- --run && npm run test:e2e", + "test": "npm run test:ui -- --run && npm run test:client -- --run && npm run test:unit -- --run && npm run test:e2e", "test:e2e": "playwright test", "test:client": "vitest --project=client", - "test:server": "vitest --project=server", + "test:unit": "vitest --project=unit", "test:ui": "vitest --project=ui", - "test:unit": "vitest", "storybook": "storybook dev -p 6006", "build-storybook": "storybook build", "cleanup": "rm -rf .svelte-kit build node_modules test-results" diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte index b5fe3fa9c4..0b0bf52ad9 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte @@ -241,7 +241,7 @@ {/if} {:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent} - + {:else if isAudio}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte index 6fdd857214..908db5894b 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte @@ -1,6 +1,6 @@
@@ -229,6 +264,25 @@
{/if}
+ +
+

Delete All Conversations

+ +

+ Permanently delete all conversations and their messages. This action cannot be undone. + Consider exporting your conversations first if you want to keep a backup. +

+ + +
@@ -249,3 +303,15 @@ onCancel={() => (showImportDialog = false)} onConfirm={handleImportConfirm} /> + + diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte index 1d313e284e..aa0c27f6d3 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte @@ -9,6 +9,7 @@ import Input from '$lib/components/ui/input/input.svelte'; import { conversationsStore, conversations } from '$lib/stores/conversations.svelte'; import { chatStore } from '$lib/stores/chat.svelte'; + import { getPreviewText } from '$lib/utils/text'; import ChatSidebarActions from './ChatSidebarActions.svelte'; const sidebar = Sidebar.useSidebar(); @@ -20,6 +21,9 @@ let showEditDialog = $state(false); let selectedConversation = $state(null); let editedName = $state(''); + let selectedConversationNamePreview = $derived.by(() => + selectedConversation ? getPreviewText(selectedConversation.name) : '' + ); let filteredConversations = $derived.by(() => { if (searchQuery.trim().length > 0) { @@ -162,7 +166,7 @@ bind:open={showDeleteDialog} title="Delete Conversation" description={selectedConversation - ? `Are you sure you want to delete "${selectedConversation.name}"? This action cannot be undone and will permanently remove all messages in this conversation.` + ? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.` : ''} confirmText="Delete" cancelText="Cancel" diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte index 9c37bde0d4..2a4a39535e 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -504,6 +504,14 @@ background: hsl(var(--muted) / 0.1); } + /* User message markdown should keep table borders visible on light primary backgrounds */ + div.markdown-user-content :global(table), + div.markdown-user-content :global(th), + div.markdown-user-content :global(td), + div.markdown-user-content :global(.table-wrapper) { + border-color: currentColor; + } + /* Horizontal rules */ div :global(hr) { border: none; @@ -642,6 +650,21 @@ background: var(--muted); } + /* Disable hover effects when rendering user messages */ + .markdown-user-content :global(a), + .markdown-user-content :global(a:hover) { + color: var(--primary-foreground); + } + + .markdown-user-content :global(table:hover) { + box-shadow: none; + } + + .markdown-user-content :global(th:hover), + .markdown-user-content :global(td:hover) { + background: inherit; + } + /* Enhanced blockquotes */ div :global(blockquote) { transition: all 0.2s ease; diff --git a/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte b/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte index f36a9a20b9..bc42f9dd1e 100644 --- a/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte +++ b/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte @@ -72,9 +72,10 @@
-

+	
{@html highlightedHtml}
diff --git a/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte index ac0937696d..efc9cd4e2f 100644 --- a/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte +++ b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte @@ -179,51 +179,37 @@ }); }); + // Handle changes to the model selector pop-down or the model dialog, depending on if the server is in + // router mode or not. function handleOpenChange(open: boolean) { if (loading || updating) return; - if (open) { - isOpen = true; - searchTerm = ''; - highlightedIndex = -1; + if (isRouter) { + if (open) { + isOpen = true; + searchTerm = ''; + highlightedIndex = -1; - // Focus search input after popover opens - tick().then(() => { - requestAnimationFrame(() => searchInputRef?.focus()); - }); + // Focus search input after popover opens + tick().then(() => { + requestAnimationFrame(() => searchInputRef?.focus()); + }); - if (isRouter) { modelsStore.fetchRouterModels().then(() => { modelsStore.fetchModalitiesForLoadedModels(); }); + } else { + isOpen = false; + searchTerm = ''; + highlightedIndex = -1; } } else { - isOpen = false; - searchTerm = ''; - highlightedIndex = -1; + showModelDialog = open; } } - function handleTriggerClick() { - if (loading || updating) return; - - if (!isRouter) { - // Single model mode: show dialog instead of popover - showModelDialog = true; - } - // For router mode, the Popover handles open/close - } - export function open() { - if (isRouter) { - handleOpenChange(true); - } else { - showModelDialog = true; - } - } - - function closeMenu() { - handleOpenChange(false); + handleOpenChange(true); } function handleSearchKeyDown(event: KeyboardEvent) { @@ -292,7 +278,7 @@ } if (shouldCloseMenu) { - closeMenu(); + handleOpenChange(false); // Focus the chat textarea after model selection requestAnimationFrame(() => { @@ -360,8 +346,181 @@ {:else} {@const selectedOption = getDisplayOption()} - - + + + + + {selectedOption?.model || 'Select model'} + + + {#if updating} + + {:else} + + {/if} + + + +
+
+ handleOpenChange(false)} + onKeyDown={handleSearchKeyDown} + /> +
+
+ {#if !isCurrentModelInCache() && currentModel} + + +
+ {/if} + {#if filteredOptions.length === 0} +

No models found.

+ {/if} + {#each filteredOptions as option, index (option.id)} + {@const status = getModelStatus(option.model)} + {@const isLoaded = status === ServerModelStatus.LOADED} + {@const isLoading = status === ServerModelStatus.LOADING} + {@const isSelected = currentModel === option.model || activeId === option.id} + {@const isCompatible = isModelCompatible(option)} + {@const isHighlighted = index === highlightedIndex} + {@const missingModalities = getMissingModalities(option)} + +
isCompatible && handleSelect(option.id)} + onmouseenter={() => (highlightedIndex = index)} + onkeydown={(e) => { + if (isCompatible && (e.key === 'Enter' || e.key === ' ')) { + e.preventDefault(); + handleSelect(option.id); + } + }} + > + {option.model} + + {#if missingModalities} + + {#if missingModalities.vision} + + + + + +

No vision support

+
+
+ {/if} + {#if missingModalities.audio} + + + + + +

No audio support

+
+
+ {/if} +
+ {/if} + + {#if isLoading} + + + + + +

Loading model...

+
+
+ {:else if isLoaded} + + + + + +

Unload model

+
+
+ {:else} + + {/if} +
+ {/each} +
+
+
+
+ {:else} + -
- {/if} - {#if filteredOptions.length === 0} -

No models found.

- {/if} - {#each filteredOptions as option, index (option.id)} - {@const status = getModelStatus(option.model)} - {@const isLoaded = status === ServerModelStatus.LOADED} - {@const isLoading = status === ServerModelStatus.LOADING} - {@const isSelected = currentModel === option.model || activeId === option.id} - {@const isCompatible = isModelCompatible(option)} - {@const isHighlighted = index === highlightedIndex} - {@const missingModalities = getMissingModalities(option)} - -
isCompatible && handleSelect(option.id)} - onmouseenter={() => (highlightedIndex = index)} - onkeydown={(e) => { - if (isCompatible && (e.key === 'Enter' || e.key === ' ')) { - e.preventDefault(); - handleSelect(option.id); - } - }} - > - {option.model} - - {#if missingModalities} - - {#if missingModalities.vision} - - - - - -

No vision support

-
-
- {/if} - {#if missingModalities.audio} - - - - - -

No audio support

-
-
- {/if} -
- {/if} - - {#if isLoading} - - - - - -

Loading model...

-
-
- {:else if isLoaded} - - - - - -

Unload model

-
-
- {:else} - - {/if} -
- {/each} - - - - + + {/if} {/if} diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index 3764a2856b..f9584d01d7 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -12,9 +12,12 @@ export const SETTING_CONFIG_DEFAULT: Record = showMessageStats: true, askForTitleConfirmation: false, pasteLongTextToFileLen: 2500, + copyTextAttachmentsAsPlainText: false, pdfAsImage: false, disableAutoScroll: false, renderUserContentAsMarkdown: false, + alwaysShowSidebarOnDesktop: false, + autoShowSidebarOnNewChat: true, autoMicOnEmpty: false, // make sure these default values are in sync with `common.h` samplers: 'top_k;typ_p;top_p;min_p;temperature', @@ -50,6 +53,8 @@ export const SETTING_CONFIG_INFO: Record = { 'Choose the color theme for the interface. You can choose between System (follows your device settings), Light, or Dark.', pasteLongTextToFileLen: 'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.', + copyTextAttachmentsAsPlainText: + 'When copying a message with text attachments, combine them into a single plain text string instead of a special format that can be pasted back as attachments.', samplers: 'The order at which samplers are applied, in simplified way. Default is "top_k;typ_p;top_p;min_p;temperature": top_k->typ_p->top_p->min_p->temperature', temperature: @@ -96,6 +101,10 @@ export const SETTING_CONFIG_INFO: Record = { disableAutoScroll: 'Disable automatic scrolling while messages stream so you can control the viewport position manually.', renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.', + alwaysShowSidebarOnDesktop: + 'Always keep the sidebar visible on desktop instead of auto-hiding it.', + autoShowSidebarOnNewChat: + 'Automatically show sidebar when starting a new chat. Disable to keep the sidebar hidden until you click on it.', autoMicOnEmpty: 'Automatically show microphone button instead of send button when textarea is empty for models with audio modality support.', pyInterpreterEnabled: diff --git a/tools/server/webui/src/lib/stores/conversations.svelte.ts b/tools/server/webui/src/lib/stores/conversations.svelte.ts index f766561971..3300eb3113 100644 --- a/tools/server/webui/src/lib/stores/conversations.svelte.ts +++ b/tools/server/webui/src/lib/stores/conversations.svelte.ts @@ -385,8 +385,7 @@ class ConversationsStore { this.conversations = this.conversations.filter((c) => c.id !== convId); if (this.activeConversation?.id === convId) { - this.activeConversation = null; - this.activeMessages = []; + this.clearActiveConversation(); await goto(`?new_chat=true#/`); } } catch (error) { @@ -394,6 +393,29 @@ class ConversationsStore { } } + /** + * Deletes all conversations and their messages + */ + async deleteAll(): Promise { + try { + const allConversations = await DatabaseService.getAllConversations(); + + for (const conv of allConversations) { + await DatabaseService.deleteConversation(conv.id); + } + + this.clearActiveConversation(); + this.conversations = []; + + toast.success('All conversations deleted'); + + await goto(`?new_chat=true#/`); + } catch (error) { + console.error('Failed to delete all conversations:', error); + toast.error('Failed to delete conversations'); + } + } + // ───────────────────────────────────────────────────────────────────────────── // Import/Export // ───────────────────────────────────────────────────────────────────────────── diff --git a/tools/server/webui/src/lib/utils/clipboard.ts b/tools/server/webui/src/lib/utils/clipboard.ts new file mode 100644 index 0000000000..91e8ea75ae --- /dev/null +++ b/tools/server/webui/src/lib/utils/clipboard.ts @@ -0,0 +1,262 @@ +import { toast } from 'svelte-sonner'; +import { AttachmentType } from '$lib/enums'; +import type { + DatabaseMessageExtra, + DatabaseMessageExtraTextFile, + DatabaseMessageExtraLegacyContext +} from '$lib/types/database'; + +/** + * Copy text to clipboard with toast notification + * Uses modern clipboard API when available, falls back to legacy method for non-secure contexts + * @param text - Text to copy to clipboard + * @param successMessage - Custom success message (optional) + * @param errorMessage - Custom error message (optional) + * @returns Promise - True if successful, false otherwise + */ +export async function copyToClipboard( + text: string, + successMessage = 'Copied to clipboard', + errorMessage = 'Failed to copy to clipboard' +): Promise { + try { + // Try modern clipboard API first (secure contexts only) + if (navigator.clipboard && navigator.clipboard.writeText) { + await navigator.clipboard.writeText(text); + toast.success(successMessage); + return true; + } + + // Fallback for non-secure contexts + const textArea = document.createElement('textarea'); + textArea.value = text; + textArea.style.position = 'fixed'; + textArea.style.left = '-999999px'; + textArea.style.top = '-999999px'; + document.body.appendChild(textArea); + textArea.focus(); + textArea.select(); + + const successful = document.execCommand('copy'); + document.body.removeChild(textArea); + + if (successful) { + toast.success(successMessage); + return true; + } else { + throw new Error('execCommand failed'); + } + } catch (error) { + console.error('Failed to copy to clipboard:', error); + toast.error(errorMessage); + return false; + } +} + +/** + * Copy code with HTML entity decoding and toast notification + * @param rawCode - Raw code string that may contain HTML entities + * @param successMessage - Custom success message (optional) + * @param errorMessage - Custom error message (optional) + * @returns Promise - True if successful, false otherwise + */ +export async function copyCodeToClipboard( + rawCode: string, + successMessage = 'Code copied to clipboard', + errorMessage = 'Failed to copy code' +): Promise { + const doc = new DOMParser().parseFromString(rawCode, 'text/html'); + const decodedCode = doc.body.textContent ?? rawCode; + + return copyToClipboard(decodedCode, successMessage, errorMessage); +} + +/** + * Format for text attachments when copied to clipboard + */ +export interface ClipboardTextAttachment { + type: typeof AttachmentType.TEXT; + name: string; + content: string; +} + +/** + * Parsed result from clipboard content + */ +export interface ParsedClipboardContent { + message: string; + textAttachments: ClipboardTextAttachment[]; +} + +/** + * Formats a message with text attachments for clipboard copying. + * + * Default format (asPlainText = false): + * ``` + * "Text message content" + * [ + * {"type":"TEXT","name":"filename.txt","content":"..."}, + * {"type":"TEXT","name":"another.txt","content":"..."} + * ] + * ``` + * + * Plain text format (asPlainText = true): + * ``` + * Text message content + * + * file content here + * + * another file content + * ``` + * + * @param content - The message text content + * @param extras - Optional array of message attachments + * @param asPlainText - If true, format as plain text without JSON structure + * @returns Formatted string for clipboard + */ +export function formatMessageForClipboard( + content: string, + extras?: DatabaseMessageExtra[], + asPlainText: boolean = false +): string { + // Filter only text attachments (TEXT type and legacy CONTEXT type) + const textAttachments = + extras?.filter( + (extra): extra is DatabaseMessageExtraTextFile | DatabaseMessageExtraLegacyContext => + extra.type === AttachmentType.TEXT || extra.type === AttachmentType.LEGACY_CONTEXT + ) ?? []; + + if (textAttachments.length === 0) { + return content; + } + + if (asPlainText) { + const parts = [content]; + for (const att of textAttachments) { + parts.push(att.content); + } + return parts.join('\n\n'); + } + + const clipboardAttachments: ClipboardTextAttachment[] = textAttachments.map((att) => ({ + type: AttachmentType.TEXT, + name: att.name, + content: att.content + })); + + return `${JSON.stringify(content)}\n${JSON.stringify(clipboardAttachments, null, 2)}`; +} + +/** + * Parses clipboard content to extract message and text attachments. + * Supports both plain text and the special format with attachments. + * + * @param clipboardText - Raw text from clipboard + * @returns Parsed content with message and attachments + */ +export function parseClipboardContent(clipboardText: string): ParsedClipboardContent { + const defaultResult: ParsedClipboardContent = { + message: clipboardText, + textAttachments: [] + }; + + if (!clipboardText.startsWith('"')) { + return defaultResult; + } + + try { + let stringEndIndex = -1; + let escaped = false; + + for (let i = 1; i < clipboardText.length; i++) { + const char = clipboardText[i]; + + if (escaped) { + escaped = false; + continue; + } + + if (char === '\\') { + escaped = true; + continue; + } + + if (char === '"') { + stringEndIndex = i; + break; + } + } + + if (stringEndIndex === -1) { + return defaultResult; + } + + const jsonStringPart = clipboardText.substring(0, stringEndIndex + 1); + const remainingPart = clipboardText.substring(stringEndIndex + 1).trim(); + + const message = JSON.parse(jsonStringPart) as string; + + if (!remainingPart || !remainingPart.startsWith('[')) { + return { + message, + textAttachments: [] + }; + } + + const attachments = JSON.parse(remainingPart) as unknown[]; + + const validAttachments: ClipboardTextAttachment[] = []; + + for (const att of attachments) { + if (isValidTextAttachment(att)) { + validAttachments.push({ + type: AttachmentType.TEXT, + name: att.name, + content: att.content + }); + } + } + + return { + message, + textAttachments: validAttachments + }; + } catch { + return defaultResult; + } +} + +/** + * Type guard to validate a text attachment object + * @param obj The object to validate + * @returns true if the object is a valid text attachment + */ +function isValidTextAttachment( + obj: unknown +): obj is { type: string; name: string; content: string } { + if (typeof obj !== 'object' || obj === null) { + return false; + } + + const record = obj as Record; + + return ( + (record.type === AttachmentType.TEXT || record.type === 'TEXT') && + typeof record.name === 'string' && + typeof record.content === 'string' + ); +} + +/** + * Checks if clipboard content contains our special format with attachments + * @param clipboardText - Raw text from clipboard + * @returns true if the clipboard content contains our special format with attachments + */ +export function hasClipboardAttachments(clipboardText: string): boolean { + if (!clipboardText.startsWith('"')) { + return false; + } + + const parsed = parseClipboardContent(clipboardText); + return parsed.textAttachments.length > 0; +} diff --git a/tools/server/webui/src/lib/utils/copy.ts b/tools/server/webui/src/lib/utils/copy.ts deleted file mode 100644 index 16a4bbd45d..0000000000 --- a/tools/server/webui/src/lib/utils/copy.ts +++ /dev/null @@ -1,71 +0,0 @@ -import { toast } from 'svelte-sonner'; - -/** - * Copy text to clipboard with toast notification - * Uses modern clipboard API when available, falls back to legacy method for non-secure contexts - * @param text - Text to copy to clipboard - * @param successMessage - Custom success message (optional) - * @param errorMessage - Custom error message (optional) - * @returns Promise - True if successful, false otherwise - */ -export async function copyToClipboard( - text: string, - successMessage = 'Copied to clipboard', - errorMessage = 'Failed to copy to clipboard' -): Promise { - try { - // Try modern clipboard API first (secure contexts only) - if (navigator.clipboard && navigator.clipboard.writeText) { - await navigator.clipboard.writeText(text); - toast.success(successMessage); - return true; - } - - // Fallback for non-secure contexts - const textArea = document.createElement('textarea'); - textArea.value = text; - textArea.style.position = 'fixed'; - textArea.style.left = '-999999px'; - textArea.style.top = '-999999px'; - document.body.appendChild(textArea); - textArea.focus(); - textArea.select(); - - const successful = document.execCommand('copy'); - document.body.removeChild(textArea); - - if (successful) { - toast.success(successMessage); - return true; - } else { - throw new Error('execCommand failed'); - } - } catch (error) { - console.error('Failed to copy to clipboard:', error); - toast.error(errorMessage); - return false; - } -} - -/** - * Copy code with HTML entity decoding and toast notification - * @param rawCode - Raw code string that may contain HTML entities - * @param successMessage - Custom success message (optional) - * @param errorMessage - Custom error message (optional) - * @returns Promise - True if successful, false otherwise - */ -export async function copyCodeToClipboard( - rawCode: string, - successMessage = 'Code copied to clipboard', - errorMessage = 'Failed to copy code' -): Promise { - // Decode HTML entities - const decodedCode = rawCode - .replace(/&/g, '&') - .replace(/</g, '<') - .replace(/>/g, '>') - .replace(/"/g, '"') - .replace(/'/g, "'"); - - return copyToClipboard(decodedCode, successMessage, errorMessage); -} diff --git a/tools/server/webui/src/lib/utils/file-preview.ts b/tools/server/webui/src/lib/utils/file-preview.ts index 115f8727a9..26a60533ae 100644 --- a/tools/server/webui/src/lib/utils/file-preview.ts +++ b/tools/server/webui/src/lib/utils/file-preview.ts @@ -34,12 +34,3 @@ export function getFileTypeLabel(input: string | undefined): string { // Handle AttachmentType or other plain strings return input.toUpperCase(); } - -/** - * Truncates text content for preview display - * @param content - The text content to truncate - * @returns Truncated content with ellipsis if needed - */ -export function getPreviewText(content: string): string { - return content.length > 150 ? content.substring(0, 150) + '...' : content; -} diff --git a/tools/server/webui/src/lib/utils/index.ts b/tools/server/webui/src/lib/utils/index.ts index d8a893ed64..ab60061991 100644 --- a/tools/server/webui/src/lib/utils/index.ts +++ b/tools/server/webui/src/lib/utils/index.ts @@ -40,10 +40,19 @@ export { setConfigValue, getConfigValue, configToParameterRecord } from './confi export { createMessageCountMap, getMessageCount } from './conversation-utils'; // Clipboard utilities -export { copyToClipboard, copyCodeToClipboard } from './copy'; +export { + copyToClipboard, + copyCodeToClipboard, + formatMessageForClipboard, + parseClipboardContent, + hasClipboardAttachments, + type ClipboardTextAttachment, + type ParsedClipboardContent +} from './clipboard'; // File preview utilities -export { getFileTypeLabel, getPreviewText } from './file-preview'; +export { getFileTypeLabel } from './file-preview'; +export { getPreviewText } from './text'; // File type utilities export { diff --git a/tools/server/webui/src/lib/utils/text.ts b/tools/server/webui/src/lib/utils/text.ts new file mode 100644 index 0000000000..5c5dd0fe8c --- /dev/null +++ b/tools/server/webui/src/lib/utils/text.ts @@ -0,0 +1,7 @@ +/** + * Returns a shortened preview of the provided content capped at the given length. + * Appends an ellipsis when the content exceeds the maximum. + */ +export function getPreviewText(content: string, max = 150): string { + return content.length > max ? content.slice(0, max) + '...' : content; +} diff --git a/tools/server/webui/src/routes/+layout.svelte b/tools/server/webui/src/routes/+layout.svelte index 27dfac19c3..17e13e9f33 100644 --- a/tools/server/webui/src/routes/+layout.svelte +++ b/tools/server/webui/src/routes/+layout.svelte @@ -14,6 +14,7 @@ import { goto } from '$app/navigation'; import { modelsStore } from '$lib/stores/models.svelte'; import { TOOLTIP_DELAY_DURATION } from '$lib/constants/tooltip-config'; + import { IsMobile } from '$lib/hooks/is-mobile.svelte'; let { children } = $props(); @@ -21,6 +22,10 @@ let isHomeRoute = $derived(page.route.id === '/'); let isNewChatMode = $derived(page.url.searchParams.get('new_chat') === 'true'); let showSidebarByDefault = $derived(activeMessages().length > 0 || isLoading()); + let alwaysShowSidebarOnDesktop = $derived(config().alwaysShowSidebarOnDesktop); + let autoShowSidebarOnNewChat = $derived(config().autoShowSidebarOnNewChat); + let isMobile = new IsMobile(); + let isDesktop = $derived(!isMobile.current); let sidebarOpen = $state(false); let innerHeight = $state(); let chatSidebar: @@ -76,6 +81,11 @@ } $effect(() => { + if (alwaysShowSidebarOnDesktop && isDesktop) { + sidebarOpen = true; + return; + } + if (isHomeRoute && !isNewChatMode) { // Auto-collapse sidebar when navigating to home route (but not in new chat mode) sidebarOpen = false; @@ -83,8 +93,11 @@ // Keep sidebar open in new chat mode sidebarOpen = true; } else if (isChatRoute) { - // On chat routes, show sidebar by default - sidebarOpen = true; + // On chat routes, only auto-show sidebar if setting is enabled + if (autoShowSidebarOnNewChat) { + sidebarOpen = true; + } + // If setting is disabled, don't change sidebar state - let user control it manually } else { // Other routes follow default behavior sidebarOpen = showSidebarByDefault; @@ -190,12 +203,14 @@ - + {#if !(alwaysShowSidebarOnDesktop && isDesktop)} + + {/if} {@render children?.()} diff --git a/tools/server/webui/tests/server/demo.spec.ts b/tools/server/webui/tests/server/demo.spec.ts deleted file mode 100644 index e07cbbd725..0000000000 --- a/tools/server/webui/tests/server/demo.spec.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { describe, it, expect } from 'vitest'; - -describe('sum test', () => { - it('adds 1 + 2 to equal 3', () => { - expect(1 + 2).toBe(3); - }); -}); diff --git a/tools/server/webui/tests/unit/clipboard.test.ts b/tools/server/webui/tests/unit/clipboard.test.ts new file mode 100644 index 0000000000..d8ea4899e2 --- /dev/null +++ b/tools/server/webui/tests/unit/clipboard.test.ts @@ -0,0 +1,423 @@ +import { describe, it, expect } from 'vitest'; +import { AttachmentType } from '$lib/enums'; +import { + formatMessageForClipboard, + parseClipboardContent, + hasClipboardAttachments +} from '$lib/utils/clipboard'; + +describe('formatMessageForClipboard', () => { + it('returns plain content when no extras', () => { + const result = formatMessageForClipboard('Hello world', undefined); + expect(result).toBe('Hello world'); + }); + + it('returns plain content when extras is empty array', () => { + const result = formatMessageForClipboard('Hello world', []); + expect(result).toBe('Hello world'); + }); + + it('handles empty string content', () => { + const result = formatMessageForClipboard('', undefined); + expect(result).toBe(''); + }); + + it('returns plain content when extras has only non-text attachments', () => { + const extras = [ + { + type: AttachmentType.IMAGE as const, + name: 'image.png', + base64Url: 'data:image/png;base64,...' + } + ]; + const result = formatMessageForClipboard('Hello world', extras); + expect(result).toBe('Hello world'); + }); + + it('filters non-text attachments and keeps only text ones', () => { + const extras = [ + { + type: AttachmentType.IMAGE as const, + name: 'image.png', + base64Url: 'data:image/png;base64,...' + }, + { + type: AttachmentType.TEXT as const, + name: 'file.txt', + content: 'Text content' + }, + { + type: AttachmentType.PDF as const, + name: 'doc.pdf', + base64Data: 'data:application/pdf;base64,...', + content: 'PDF content', + processedAsImages: false + } + ]; + const result = formatMessageForClipboard('Hello', extras); + + expect(result).toContain('"file.txt"'); + expect(result).not.toContain('image.png'); + expect(result).not.toContain('doc.pdf'); + }); + + it('formats message with text attachments', () => { + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'file1.txt', + content: 'File 1 content' + }, + { + type: AttachmentType.TEXT as const, + name: 'file2.txt', + content: 'File 2 content' + } + ]; + const result = formatMessageForClipboard('Hello world', extras); + + expect(result).toContain('"Hello world"'); + expect(result).toContain('"type": "TEXT"'); + expect(result).toContain('"name": "file1.txt"'); + expect(result).toContain('"content": "File 1 content"'); + expect(result).toContain('"name": "file2.txt"'); + }); + + it('handles content with quotes and special characters', () => { + const content = 'Hello "world" with\nnewline'; + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'test.txt', + content: 'Test content' + } + ]; + const result = formatMessageForClipboard(content, extras); + + // Should be valid JSON + expect(result.startsWith('"')).toBe(true); + // The content should be properly escaped + const parsed = JSON.parse(result.split('\n')[0]); + expect(parsed).toBe(content); + }); + + it('converts legacy context type to TEXT type', () => { + const extras = [ + { + type: AttachmentType.LEGACY_CONTEXT as const, + name: 'legacy.txt', + content: 'Legacy content' + } + ]; + const result = formatMessageForClipboard('Hello', extras); + + expect(result).toContain('"type": "TEXT"'); + expect(result).not.toContain('"context"'); + }); + + it('handles attachment content with special characters', () => { + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'code.js', + content: 'const x = "hello\\nworld";\nconst y = `template ${var}`;' + } + ]; + const formatted = formatMessageForClipboard('Check this code', extras); + const parsed = parseClipboardContent(formatted); + + expect(parsed.textAttachments[0].content).toBe( + 'const x = "hello\\nworld";\nconst y = `template ${var}`;' + ); + }); + + it('handles unicode characters in content and attachments', () => { + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'unicode.txt', + content: '日本語テスト 🎉 émojis' + } + ]; + const formatted = formatMessageForClipboard('Привет мир 👋', extras); + const parsed = parseClipboardContent(formatted); + + expect(parsed.message).toBe('Привет мир 👋'); + expect(parsed.textAttachments[0].content).toBe('日本語テスト 🎉 émojis'); + }); + + it('formats as plain text when asPlainText is true', () => { + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'file1.txt', + content: 'File 1 content' + }, + { + type: AttachmentType.TEXT as const, + name: 'file2.txt', + content: 'File 2 content' + } + ]; + const result = formatMessageForClipboard('Hello world', extras, true); + + expect(result).toBe('Hello world\n\nFile 1 content\n\nFile 2 content'); + }); + + it('returns plain content when asPlainText is true but no attachments', () => { + const result = formatMessageForClipboard('Hello world', [], true); + expect(result).toBe('Hello world'); + }); + + it('plain text mode does not use JSON format', () => { + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'test.txt', + content: 'Test content' + } + ]; + const result = formatMessageForClipboard('Hello', extras, true); + + expect(result).not.toContain('"type"'); + expect(result).not.toContain('['); + expect(result).toBe('Hello\n\nTest content'); + }); +}); + +describe('parseClipboardContent', () => { + it('returns plain text as message when not in special format', () => { + const result = parseClipboardContent('Hello world'); + + expect(result.message).toBe('Hello world'); + expect(result.textAttachments).toHaveLength(0); + }); + + it('handles empty string input', () => { + const result = parseClipboardContent(''); + + expect(result.message).toBe(''); + expect(result.textAttachments).toHaveLength(0); + }); + + it('handles whitespace-only input', () => { + const result = parseClipboardContent(' \n\t '); + + expect(result.message).toBe(' \n\t '); + expect(result.textAttachments).toHaveLength(0); + }); + + it('returns plain text as message when starts with quote but invalid format', () => { + const result = parseClipboardContent('"Unclosed quote'); + + expect(result.message).toBe('"Unclosed quote'); + expect(result.textAttachments).toHaveLength(0); + }); + + it('returns original text when JSON array is malformed', () => { + const input = '"Hello"\n[invalid json'; + + const result = parseClipboardContent(input); + + expect(result.message).toBe('"Hello"\n[invalid json'); + expect(result.textAttachments).toHaveLength(0); + }); + + it('parses message with text attachments', () => { + const input = `"Hello world" +[ + {"type":"TEXT","name":"file1.txt","content":"File 1 content"}, + {"type":"TEXT","name":"file2.txt","content":"File 2 content"} +]`; + + const result = parseClipboardContent(input); + + expect(result.message).toBe('Hello world'); + expect(result.textAttachments).toHaveLength(2); + expect(result.textAttachments[0].name).toBe('file1.txt'); + expect(result.textAttachments[0].content).toBe('File 1 content'); + expect(result.textAttachments[1].name).toBe('file2.txt'); + expect(result.textAttachments[1].content).toBe('File 2 content'); + }); + + it('handles escaped quotes in message', () => { + const input = `"Hello \\"world\\" with quotes" +[ + {"type":"TEXT","name":"file.txt","content":"test"} +]`; + + const result = parseClipboardContent(input); + + expect(result.message).toBe('Hello "world" with quotes'); + expect(result.textAttachments).toHaveLength(1); + }); + + it('handles newlines in message', () => { + const input = `"Hello\\nworld" +[ + {"type":"TEXT","name":"file.txt","content":"test"} +]`; + + const result = parseClipboardContent(input); + + expect(result.message).toBe('Hello\nworld'); + expect(result.textAttachments).toHaveLength(1); + }); + + it('returns message only when no array follows', () => { + const input = '"Just a quoted string"'; + + const result = parseClipboardContent(input); + + expect(result.message).toBe('Just a quoted string'); + expect(result.textAttachments).toHaveLength(0); + }); + + it('filters out invalid attachment objects', () => { + const input = `"Hello" +[ + {"type":"TEXT","name":"valid.txt","content":"valid"}, + {"type":"INVALID","name":"invalid.txt","content":"invalid"}, + {"name":"missing-type.txt","content":"missing"}, + {"type":"TEXT","content":"missing name"} +]`; + + const result = parseClipboardContent(input); + + expect(result.message).toBe('Hello'); + expect(result.textAttachments).toHaveLength(1); + expect(result.textAttachments[0].name).toBe('valid.txt'); + }); + + it('handles empty attachments array', () => { + const input = '"Hello"\n[]'; + + const result = parseClipboardContent(input); + + expect(result.message).toBe('Hello'); + expect(result.textAttachments).toHaveLength(0); + }); + + it('roundtrips correctly with formatMessageForClipboard', () => { + const originalContent = 'Hello "world" with\nspecial characters'; + const originalExtras = [ + { + type: AttachmentType.TEXT as const, + name: 'file1.txt', + content: 'Content with\nnewlines and "quotes"' + }, + { + type: AttachmentType.TEXT as const, + name: 'file2.txt', + content: 'Another file' + } + ]; + + const formatted = formatMessageForClipboard(originalContent, originalExtras); + const parsed = parseClipboardContent(formatted); + + expect(parsed.message).toBe(originalContent); + expect(parsed.textAttachments).toHaveLength(2); + expect(parsed.textAttachments[0].name).toBe('file1.txt'); + expect(parsed.textAttachments[0].content).toBe('Content with\nnewlines and "quotes"'); + expect(parsed.textAttachments[1].name).toBe('file2.txt'); + expect(parsed.textAttachments[1].content).toBe('Another file'); + }); +}); + +describe('hasClipboardAttachments', () => { + it('returns false for plain text', () => { + expect(hasClipboardAttachments('Hello world')).toBe(false); + }); + + it('returns false for empty string', () => { + expect(hasClipboardAttachments('')).toBe(false); + }); + + it('returns false for quoted string without attachments', () => { + expect(hasClipboardAttachments('"Hello world"')).toBe(false); + }); + + it('returns true for valid format with attachments', () => { + const input = `"Hello" +[{"type":"TEXT","name":"file.txt","content":"test"}]`; + + expect(hasClipboardAttachments(input)).toBe(true); + }); + + it('returns false for format with empty attachments array', () => { + const input = '"Hello"\n[]'; + + expect(hasClipboardAttachments(input)).toBe(false); + }); + + it('returns false for malformed JSON', () => { + expect(hasClipboardAttachments('"Hello"\n[broken')).toBe(false); + }); +}); + +describe('roundtrip edge cases', () => { + it('preserves empty message with attachments', () => { + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'file.txt', + content: 'Content only' + } + ]; + const formatted = formatMessageForClipboard('', extras); + const parsed = parseClipboardContent(formatted); + + expect(parsed.message).toBe(''); + expect(parsed.textAttachments).toHaveLength(1); + expect(parsed.textAttachments[0].content).toBe('Content only'); + }); + + it('preserves attachment with empty content', () => { + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'empty.txt', + content: '' + } + ]; + const formatted = formatMessageForClipboard('Message', extras); + const parsed = parseClipboardContent(formatted); + + expect(parsed.message).toBe('Message'); + expect(parsed.textAttachments).toHaveLength(1); + expect(parsed.textAttachments[0].content).toBe(''); + }); + + it('preserves multiple backslashes', () => { + const content = 'Path: C:\\\\Users\\\\test\\\\file.txt'; + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'path.txt', + content: 'D:\\\\Data\\\\file' + } + ]; + const formatted = formatMessageForClipboard(content, extras); + const parsed = parseClipboardContent(formatted); + + expect(parsed.message).toBe(content); + expect(parsed.textAttachments[0].content).toBe('D:\\\\Data\\\\file'); + }); + + it('preserves tabs and various whitespace', () => { + const content = 'Line1\t\tTabbed\n Spaced\r\nCRLF'; + const extras = [ + { + type: AttachmentType.TEXT as const, + name: 'whitespace.txt', + content: '\t\t\n\n ' + } + ]; + const formatted = formatMessageForClipboard(content, extras); + const parsed = parseClipboardContent(formatted); + + expect(parsed.message).toBe(content); + expect(parsed.textAttachments[0].content).toBe('\t\t\n\n '); + }); +}); diff --git a/tools/server/webui/src/lib/utils/latex-protection.test.ts b/tools/server/webui/tests/unit/latex-protection.test.ts similarity index 99% rename from tools/server/webui/src/lib/utils/latex-protection.test.ts rename to tools/server/webui/tests/unit/latex-protection.test.ts index 40fe1b0db2..84328dbc17 100644 --- a/tools/server/webui/src/lib/utils/latex-protection.test.ts +++ b/tools/server/webui/tests/unit/latex-protection.test.ts @@ -1,6 +1,6 @@ /* eslint-disable no-irregular-whitespace */ import { describe, it, expect, test } from 'vitest'; -import { maskInlineLaTeX, preprocessLaTeX } from './latex-protection'; +import { maskInlineLaTeX, preprocessLaTeX } from '$lib/utils/latex-protection'; describe('maskInlineLaTeX', () => { it('should protect LaTeX $x + y$ but not money $3.99', () => { diff --git a/tools/server/webui/src/lib/utils/model-names.test.ts b/tools/server/webui/tests/unit/model-names.test.ts similarity index 95% rename from tools/server/webui/src/lib/utils/model-names.test.ts rename to tools/server/webui/tests/unit/model-names.test.ts index ca85df3d30..40c5a0e3aa 100644 --- a/tools/server/webui/src/lib/utils/model-names.test.ts +++ b/tools/server/webui/tests/unit/model-names.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { isValidModelName, normalizeModelName } from './model-names'; +import { isValidModelName, normalizeModelName } from '$lib/utils/model-names'; describe('normalizeModelName', () => { it('preserves Hugging Face org/model format (single slash)', () => { diff --git a/tools/server/webui/vite.config.ts b/tools/server/webui/vite.config.ts index b41d3511b4..5183c09fca 100644 --- a/tools/server/webui/vite.config.ts +++ b/tools/server/webui/vite.config.ts @@ -125,9 +125,9 @@ export default defineConfig({ { extends: './vite.config.ts', test: { - name: 'server', + name: 'unit', environment: 'node', - include: ['tests/server/**/*.{test,spec}.{js,ts}'] + include: ['tests/unit/**/*.{test,spec}.{js,ts}'] } }, {