From fd1085ffb71dabf4490b428512d0fd6fdc637536 Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Sat, 13 Dec 2025 08:34:26 +0100
Subject: [PATCH 01/13] model-conversion : use CONVERTED_MODEL value for
 converted model [no ci] (#17984)

* model-conversion : use CONVERTED_MODEL value for converted model [no ci]

This commit updates the model verification scripts to use the
CONVERTED_MODEL environment variable instead of using the MODEL_PATH
(the original model path) as the basis for the converted model file
name.

The motivation for this is that, currently, if the converted model file
name differs from the original model directory/name, the verification
scripts will look for the wrong .bin files that were generated when
running the models. For example, the following steps were not possible:
```console
(venv) $ huggingface-cli download google/gemma-3-270m-it --local-dir ggml-org/gemma-3-270m
(venv) $ python3 convert_hf_to_gguf.py ggml-org/gemma-3-270m --outfile test-bf16.gguf --outtype bf16
(venv) $ cd examples/model-conversion/
(venv) $ export MODEL_PATH=../../ggml-org/gemma-3-270m
(venv) $ export CONVERTED_MODEL=../../test-bf16.gguf
(venv) $ make causal-verify-logits
...
Data saved to data/llamacpp-test-bf16.bin
Data saved to data/llamacpp-test-bf16.txt
Error: llama.cpp logits file not found: data/llamacpp-gemma-3-270m.bin
Please run scripts/run-converted-model.sh first to generate this file.
make: *** [Makefile:62: causal-verify-logits] Error 1
```

With the changes in this commit, the above steps will now work as
expected.
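
As a rough illustration of the resulting file-name resolution (a hedged
Python sketch mirroring common.py's get_model_name_from_env_path below,
not a copy of it; the paths are the ones from the example above):

```python
import os

# Sketch: derive the logits file names the way the updated scripts do,
# one name from MODEL_PATH and one from CONVERTED_MODEL.
def model_name(path: str) -> str:
    name = os.path.basename(os.path.normpath(path))
    return name[:-5] if name.endswith(".gguf") else name

assert model_name("../../ggml-org/gemma-3-270m") == "gemma-3-270m"  # -> pytorch-gemma-3-270m.bin
assert model_name("../../test-bf16.gguf") == "test-bf16"            # -> llamacpp-test-bf16.bin
```
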
{pytorch_file}") diff --git a/examples/model-conversion/scripts/utils/__init__.py b/examples/model-conversion/scripts/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/model-conversion/scripts/utils/check-nmse.py b/examples/model-conversion/scripts/utils/check-nmse.py index 939e3153cc..83f63f9ff3 100755 --- a/examples/model-conversion/scripts/utils/check-nmse.py +++ b/examples/model-conversion/scripts/utils/check-nmse.py @@ -5,6 +5,7 @@ import sys import os import argparse from pathlib import Path +from common import get_model_name_from_env_path # type: ignore[import-not-found] def calculate_nmse(reference, test): mse = np.mean((test - reference) ** 2) @@ -67,11 +68,13 @@ def main(): parser.add_argument('-m', '--model-path', required=True, help='Path to the model directory') args = parser.parse_args() - model_name = os.path.basename(args.model_path) + model_name = get_model_name_from_env_path('MODEL_PATH') data_dir = Path("data") pytorch_file = data_dir / f"pytorch-{model_name}.bin" - llamacpp_file = data_dir / f"llamacpp-{model_name}.bin" + + llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL') + llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin" print(f"Model name: {model_name}") print(f"PyTorch logits file: {pytorch_file}") diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py new file mode 100644 index 0000000000..945f9a1a1d --- /dev/null +++ b/examples/model-conversion/scripts/utils/common.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 + +import os +import sys + +def get_model_name_from_env_path(env_path_name): + model_path = os.getenv(env_path_name) + if not model_path: + print(f"Error: {env_path_name} environment variable not set") + sys.exit(1) + + if not os.path.exists(model_path): + print(f"Error: Model file not found: {model_path}") + sys.exit(1) + + name = os.path.basename(os.path.normpath(model_path)) + if name.endswith(".gguf"): + name = name[:-5] + + return name diff --git a/pyrightconfig.json b/pyrightconfig.json index 5320fe5864..a7bc007bdc 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,5 +1,5 @@ { - "extraPaths": ["gguf-py"], + "extraPaths": ["gguf-py", "examples/model-conversion/scripts"], "pythonVersion": "3.9", "pythonPlatform": "All", "reportUnusedImport": "warning", From 2bc94e792867497e2f0088b17d3fd3fcebd1f44b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 13 Dec 2025 08:35:50 +0100 Subject: [PATCH 02/13] add llama-completion to completion-bash executables (#17976) --- common/arg.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/arg.cpp b/common/arg.cpp index 16cb2e03a6..788a9ab4e6 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -639,6 +639,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) { "llama-batched-bench", "llama-bench", "llama-cli", + "llama-completion", "llama-convert-llama2c-to-ggml", "llama-cvector-generator", "llama-embedding", From 07a10c1090502163ab48cc14a5dfc7059949a2c9 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 13 Dec 2025 01:40:04 -0600 Subject: [PATCH 03/13] vulkan: Allow non-pow2 n_experts in topk_moe (#17872) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 22 ++++++++++++------- .../ggml-vulkan/vulkan-shaders/topk_moe.comp | 12 +++++++--- tests/test-backend-ops.cpp | 4 ++++ 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 22 ++++++++++++-------
 .../ggml-vulkan/vulkan-shaders/topk_moe.comp  | 12 +++++++---
 tests/test-backend-ops.cpp                    |  4 ++++
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index c6f5809ccd..52e7e1e7fe 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -757,7 +757,8 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
+    // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
 
     std::vector all_pipelines;
 
@@ -1149,6 +1150,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
 
 struct vk_op_topk_moe_push_constants {
     uint32_t n_rows;
+    uint32_t n_experts_push;
     uint32_t n_expert_used;
     float clamp_min;
     float clamp_max;
@@ -4204,10 +4206,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
-    for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size);
+    for (uint32_t use_push = 0; use_push < 2; ++use_push) {
+        for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size);
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size);
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size);
+        }
     }
 
     for (auto &c : compiles) {
@@ -8554,7 +8558,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
         GGML_ASSERT(idx < num_topk_moe_pipelines);
         topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
-        return ctx->device->pipeline_topk_moe[idx][mode];
+        // use n_experts from push constant if it's not equal to the power of two spec constant
+        bool use_push = dst->ne[0] != (1u << idx);
+        return ctx->device->pipeline_topk_moe[idx][mode][use_push];
     }
 
     if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -10158,6 +10164,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
 
     vk_op_topk_moe_push_constants pc {};
     pc.n_rows = n_rows;
+    pc.n_experts_push = n_experts;
    pc.n_expert_used = n_expert_used;
     if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
         ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
@@ -12832,8 +12839,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
     }
 
     const int n_expert = softmax->ne[0];
-    // n_expert must be a power of 2
-    if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
+    if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
         return false;
     }
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 5cd0785d20..b83a2b9d2d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -10,6 +10,7 @@
 layout (push_constant) uniform parameter
 {
     uint n_rows;
+    uint n_experts_push;
     uint n_expert_used;
     float clamp_min;
     float clamp_max;
@@ -18,11 +19,16 @@
 layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
 
 layout(constant_id = 0) const uint WARP_SIZE = 32;
-layout(constant_id = 1) const uint n_experts = 512;
+layout(constant_id = 1) const uint n_experts_spec = 512;
 layout(constant_id = 2) const bool with_norm = true;
 layout(constant_id = 3) const bool late_softmax = false;
+layout(constant_id = 4) const bool nexperts_use_push = false;
 
-const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
 
 layout (binding = 0, std430) readonly buffer Logits {float logits[];};
 layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
@@ -94,7 +100,7 @@ void main() {
     }
 
     if (!late_softmax) {
-        softmax_warp_inplace(wt, n_experts, lane, false);
+        softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
     }
 
     // at this point, each thread holds a portion of softmax,
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 308e752b1d..f1bed864e3 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7971,8 +7971,12 @@ static std::vector> make_test_cases_eval() {
 
     for (bool with_norm : {false, true}) {
         test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
+        test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
         test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
+        test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
+        test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
         test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
+        test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
     }
 
     test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));

From 8e4d678528b2582eb8a9f773309e88b6a167e4cd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Sat, 13 Dec 2025 08:40:50 +0100
Subject: [PATCH 04/13] common : skip model validation when --completion-bash
 is requested (#17975)

---
 common/arg.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 788a9ab4e6..19f22f883f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -504,7 +504,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_contex
 
     // model is required (except for server)
     // TODO @ngxson : maybe show a list of available models in CLI in this case
-    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage) {
+    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
         throw std::invalid_argument("error: --model is required\n");
     }
 

From 3c6391e748d8c00c45fba033811508288580cdc7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 13 Dec 2025 09:48:34 +0200
Subject: [PATCH 05/13] speculative-simple : free batch on exit (#17985)

---
 examples/speculative-simple/speculative-simple.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index a8e53f28eb..0d11d0f803 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -255,6 +255,8 @@ int main(int argc, char ** argv) {
     LOG_INF("target:\n\n");
     common_perf_print(ctx_tgt, smpl);
 
+    llama_batch_free(batch_tgt);
+
     common_sampler_free(smpl);
     common_speculative_free(spec);
 

From 303f8615e94d74f140eb4c3947758c1eca933c3a Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Sat, 13 Dec 2025 03:04:29 -0600
Subject: [PATCH 06/13] vulkan: Multi-pass softmax for large number of cols
 (#17892)

When the number of cols is large, split each row across multiple
workgroups.
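
A rough NumPy sketch of the resulting computation (illustrative only,
not the shader code; the chunks stand in for workgroups and the partial
arrays for the temp buffers described below):

```python
import numpy as np

def softmax_multipass(x: np.ndarray, n_chunks: int = 4) -> np.ndarray:
    chunks = np.array_split(x, n_chunks)
    maxes = np.array([c.max() for c in chunks])             # phase 1: max partials
    m = maxes.max()                                         # phase 2: reduce the maxes,
    sums = np.array([np.exp(c - m).sum() for c in chunks])  #          sum(exp(x-max)) partials
    s = sums.sum()                                          # phase 3: reduce the sums,
    return np.exp(x - m) / s                                #          scale the result

row = np.random.randn(200001).astype(np.float32)            # rows > 16384 cols take this path
ref = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
assert np.allclose(softmax_multipass(row), ref, atol=1e-6)
```
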
There are three phases that communicate partial results through temp
buffers:

(1) compute max partials
(2) take max of partials, compute sum(exp(x-max)) partials
(3) sum partials, compute scaled result
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 64 ++++++++++++++-
 .../vulkan-shaders/soft_max_large1.comp       | 62 +++++++++++++++
 .../vulkan-shaders/soft_max_large2.comp       | 79 +++++++++++++++++++
 .../vulkan-shaders/soft_max_large3.comp       | 65 +++++++++++++++
 .../vulkan-shaders/soft_max_large_common.glsl | 53 +++++++++++++
 .../vulkan-shaders/vulkan-shaders-gen.cpp     |  7 ++
 tests/test-backend-ops.cpp                    |  3 +
 7 files changed, 331 insertions(+), 2 deletions(-)
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 52e7e1e7fe..0b4f8c36ad 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -722,6 +722,11 @@ struct vk_device_struct {
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
     vk_pipeline pipeline_soft_max_back_f32;
+
+    vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
+    vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
+    vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
+
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -3998,6 +4003,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
 
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -10117,7 +10129,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
+    vk_op_soft_max_push_constants pc {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -10128,7 +10140,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         n_head_log2,
         nrows_x,
         src2 != nullptr
-    });
+    };
+
+    if (ncols <= 16384) {
+        ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
+    } else {
+
+        vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
+        vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
+        vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
+        vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
+
+        uint32_t elems_per_wg = 128 * 4;
+        uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
+        size_t tmp_size = num_wgs * nrows_x * sizeof(float);
+
+        if (ctx->prealloc_size_x < tmp_size) {
+            ctx->prealloc_size_x = tmp_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_size_y < tmp_size) {
+            ctx->prealloc_size_y = tmp_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+
+        vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
+        vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
+
+        std::array elements = { num_wgs, nrows_x, 1 };
+
+        vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
+        vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
+        vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
+
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+        ggml_vk_sync_buffers(ctx, subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+        ggml_vk_sync_buffers(ctx, subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+
+        ctx->prealloc_x_need_sync = true;
+        ctx->prealloc_y_need_sync = true;
+    }
 }
 
 static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
new file mode 100644
index 0000000000..39c4663912
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
@@ -0,0 +1,62 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    float slope = get_slope(rowx);
+
+    // Find max
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        FLOAT_TYPE a = FLOAT_TYPE(0);
+        if (col < p.KX) {
+            a = data_a[rowx * p.KX + col];
+        }
+
+        FLOAT_TYPE b = FLOAT_TYPE(0);
+        if (p.KY > 0 && col < p.KX) {
+            b = data_b[rowy_start + col];
+        }
+
+        FLOAT_TYPE v = a * p.scale + slope * b;
+
+        if (col < p.KX) {
+            max_val = max(max_val, v);
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(vals[tid], vals[tid + s]);
+        }
+        barrier();
+    }
+
+    if (tid == 0) {
+        max_val = vals[0];
+        data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp
new file mode 100644
index 0000000000..69524f5f75
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp
@@ -0,0 +1,79 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    float slope = get_slope(rowx);
+
+    // Find max
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+
+    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
+        if (i + tid < gl_NumWorkGroups.x) {
+            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(max_val, vals[tid + s]);
+        }
+        barrier();
+    }
+
+    max_val = vals[0];
+    barrier();
+
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
+
+    // Compute sum{exp(x - max)}
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
+        // compute exp(a*scale+b*slope), add it to sum
+        const uint i = rowx * p.KX + col;
+        FLOAT_TYPE val;
+        val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
+        sum += val;
+        data_d[i] = D_TYPE(val);
+    }
+
+    // reduce across the workgroup
+    vals[tid] = sum;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] += vals[tid + s];
+        }
+        barrier();
+    }
+
+    if (tid == 0) {
+        sum = vals[0];
+        data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp
new file mode 100644
index 0000000000..06efd7d9fb
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp
@@ -0,0 +1,65 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+shared FLOAT_TYPE sumsh[BLOCK_SIZE];
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
+
+    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
+        if (i + tid < gl_NumWorkGroups.x) {
+            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
+            sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    sumsh[tid] = sum;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(max_val, vals[tid + s]);
+            sumsh[tid] += sumsh[tid + s];
+        }
+        barrier();
+    }
+
+    max_val = vals[0];
+    sum = sumsh[0];
+
+    if (p.has_sinks != 0) {
+        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
+    }
+
+    FLOAT_TYPE rcpdivisor = 1.0/sum;
+
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            continue;
+        }
+
+        data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl
new file mode 100644
index 0000000000..6636d1f8de
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl
@@ -0,0 +1,53 @@
+#extension GL_EXT_control_flow_attributes : enable
+
+layout (push_constant) uniform parameter
+{
+    uint KX;
+    uint KY;
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint ne12;
+    uint ne13;
+    uint nb11;
+    uint nb12;
+    uint nb13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+    uint n_head_log2;
+    uint nrows_x;
+    uint has_sinks;
+} p;
+
+#include "types.glsl"
+
+layout(constant_id = 0) const uint BLOCK_SIZE = 128;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+layout(constant_id = 1) const uint num_iters = 4;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) readonly buffer Z {float data_c[];};
+layout (binding = 3) buffer D {D_TYPE data_d[];};
+layout (binding = 4) buffer M {float data_m[];};
+layout (binding = 5) buffer S {float data_s[];};
+
+shared FLOAT_TYPE vals[BLOCK_SIZE];
+
+float get_slope(uint rowx) {
+    float slope = 1.0f;
+
+    // ALiBi
+    if (p.max_bias > 0.0f) {
+        const uint h = (rowx / p.ne01) % p.ne02; // head index
+
+        const float base = h < p.n_head_log2 ? p.m0 : p.m1;
+        const uint  exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    return slope;
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 92bae088b2..72c63e8173 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -899,6 +899,13 @@ void process_shaders() {
     string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
     string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
+    string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+
     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
     string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index f1bed864e3..416218b5b8 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7652,6 +7652,9 @@ static std::vector> make_test_cases_eval() {
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
 
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+
     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
             for (int64_t ne0 : {16, 1024}) {

From 3229a23fa675ab7316fc53d045ded99f90fc6766 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Sat, 13 Dec 2025 03:07:49 -0600
Subject: [PATCH 07/13] vulkan: support GGML_OP_DIAG (#17893)

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 24 +++++++++++++++
 ggml/src/ggml-vulkan/vulkan-shaders/diag.comp | 29 +++++++++++++++++++
 .../vulkan-shaders/vulkan-shaders-gen.cpp     |  2 ++
 3 files changed, 55 insertions(+)
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/diag.comp

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 0b4f8c36ad..ac3841764e 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -659,6 +659,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_log[2];
     vk_pipeline pipeline_tri[2];
+    vk_pipeline pipeline_diag[2];
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_roll_f32;
@@ -3924,6 +3925,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -8416,6 +8420,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
         }
         return nullptr;
+    case GGML_OP_DIAG:
+        if (src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
+            return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -9109,6 +9119,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_COS:
     case GGML_OP_LOG:
     case GGML_OP_TRI:
+    case GGML_OP_DIAG:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -9796,6 +9807,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
 }
 
+static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
+
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
     p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11924,6 +11941,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_TRI:
         ggml_vk_tri(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_DIAG:
+        ggml_vk_diag(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -14067,6 +14088,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_LOG:
         case GGML_OP_TRI:
+        case GGML_OP_DIAG:
             return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    op->type == op->src[0]->type;
         case GGML_OP_ARGSORT:
@@ -14657,6 +14679,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_TRI) {
         tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
+    } else if (tensor->op == GGML_OP_DIAG) {
+        tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp
new file mode 100644
index 0000000000..cd3f42f491
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp
@@ -0,0 +1,29 @@
+#version 450
+
+#include "rte.glsl"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
+    const uint i12_offset = i12*p.ne11*p.ne10;
+    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
+    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+
+    if (i10 == i11) {
+        const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);
+        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
+    } else {
+        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 72c63e8173..f606ea1085 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -854,6 +854,8 @@ void process_shaders() {
     string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
+    string_to_spv("diag_f16", "diag.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("diag_f32", "diag.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

From 36255a22681583c3787bce778132c89d74da4e4c Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Sat, 13 Dec 2025 03:12:53 -0600
Subject: [PATCH 08/13] vulkan: support get_rows for i32 (#17941)

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp              | 7 +++++++
 ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp | 4 ++--
 .../vulkan-shaders/vulkan-shaders-gen.cpp         | 8 +++++---
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ac3841764e..34ec09d403 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -3738,6 +3738,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -8294,6 +8295,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_GET_ROWS:
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
+        if (src0->type == GGML_TYPE_I32) {
+            // i32 src only supports i32 result
+            GGML_ASSERT(dst->type == GGML_TYPE_I32);
+            return ctx->device->pipeline_get_rows[src0->type];
+        }
         if (dst->type == GGML_TYPE_F16) {
             return ctx->device->pipeline_get_rows[src0->type];
         }
@@ -13964,6 +13970,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_TYPE_IQ4_XS:
                 case GGML_TYPE_IQ4_NL:
                 case GGML_TYPE_MXFP4:
+                case GGML_TYPE_I32:
                     return true;
                 default:
                     return false;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
index 76d83041ce..e88bdd057e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
@@ -26,9 +26,9 @@ void main() {
     const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
 
 #if defined(DATA_A_BF16)
-    FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
+    TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
 #else
-    FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
+    TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]);
 #endif
 #ifndef OPTIMIZATION_ERROR_WORKAROUND
     data_d[d_offset + i00] = D_TYPE(v);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index f606ea1085..b0ade078c7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -704,13 +704,15 @@ void process_shaders() {
         shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
 
         if (tname == "f16") {
-            string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
+            string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
         } else {
-            string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
+            string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
         }
-        string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
+        string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
     }
 
+    string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}});
+
     string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
     string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
     string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});

From 66ba51252ed1cf08739fbed8a538d3cb3541655d Mon Sep 17 00:00:00 2001
From: Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com>
Date: Sat, 13 Dec 2025 08:46:36 -0300
Subject: [PATCH 09/13] cmake: correct scope - link ws2_32 for MinGW/w64devkit
 builds in cpp-httplib (#17972)

* fix - w64devkit build

* fix - w64devkit build private scope
---
 vendor/cpp-httplib/CMakeLists.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt
index e90e8e2d1b..8f0d15d1fd 100644
--- a/vendor/cpp-httplib/CMakeLists.txt
+++ b/vendor/cpp-httplib/CMakeLists.txt
@@ -11,8 +11,9 @@ endif()
 target_link_libraries (${TARGET} PRIVATE Threads::Threads)
 
 if (WIN32 AND NOT MSVC)
-    target_link_libraries(${TARGET} PUBLIC ws2_32)
+    target_link_libraries(${TARGET} PRIVATE ws2_32)
 endif()
+
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
 target_compile_definitions(${TARGET} PRIVATE

From 4d5ae24c0ac79c4e360773bac58dd2c2a46b7f67 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen
Date: Sat, 13 Dec 2025 12:53:37 +0100
Subject: [PATCH 10/13] arg: fix common_params_parse not accepting negated arg
 (#17991)

---
 common/arg.cpp                 | 5 ++++-
 common/arg.h                   | 2 +-
 tests/test-arg-parser.cpp      | 4 ++++
 tools/server/server-models.cpp | 2 +-
 4 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 19f22f883f..bb2a6840ba 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -724,7 +724,7 @@ static void add_rpc_devices(const std::string & servers) {
     }
 }
 
-bool common_params_parse(int argc, char ** argv, llama_example ex, std::map & out_map) {
+bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & out_map) {
     common_params dummy_params;
     common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr);
 
@@ -733,6 +733,9 @@ bool common_params_parse(int argc, char ** argv, llama_example ex, std::map & out_map);
+bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & out_map);
 
 // initialize argument parser context - used by test-arg-parser and preset
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
index 90750b20c2..468d325e22 100644
--- a/tests/test-arg-parser.cpp
+++ b/tests/test-arg-parser.cpp
@@ -72,6 +72,10 @@ int main(void) {
     argv = {"binary_name", "--draft", "123"};
     assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING));
 
+    // negated arg
+    argv = {"binary_name", "--no-mmap"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
     printf("test-arg-parser: test valid usage\n\n");
 
diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index 6c618a673c..6be5ffbdf0 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -171,7 +171,7 @@ server_presets::server_presets(int argc, char ** argv, common_params & base_para
     }
 
     // read base args from router's argv
-    common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
+    common_params_to_map(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
 
     // remove any router-controlled args from base_args
     for (const auto & cargs : control_args) {

From 5266379bcae74214af397f36aa81b2a08b15d545 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Sat, 13 Dec 2025 09:19:51 -0600
Subject: [PATCH 11/13] llama_context: synchronize before reallocating output
 buffer (#17974)

---
 src/llama-context.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 2692297dca..9914b3276b 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1318,6 +1318,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
         LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
+        synchronize();
         buf_output = nullptr;
         logits = nullptr;
         embd = nullptr;

From 4ed2bae50d64dcff7f99cb2b28f737fda314abf5 Mon Sep 17 00:00:00 2001
From: Sergey Fedorov
Date: Sun, 14 Dec 2025 05:02:43 +0800
Subject: [PATCH 12/13] server-models.cpp: add missing (#18000)

Fixes: https://github.com/ggml-org/llama.cpp/issues/17999
---
 tools/server/server-models.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index 6be5ffbdf0..3690c0bb82 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include
 
 #ifdef _WIN32
 #include

From c00ff929dcfd150234e62f30e863bca4f1337aee Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen
Date: Sat, 13 Dec 2025 22:33:29 +0100
Subject: [PATCH 13/13] scripts: add script to compare logprobs of llama.cpp
 against other frameworks (#17947)

* scripts: add script to compare logits of llama.cpp against other frameworks

* accept custom prompt file

* fix code style

* clarify endpoint

* fix displaying

* use abs for diff

* fix vllm case

* rm output file

* rename to compare-logprobs

* add "pattern"
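
For reference, how a --pattern value is interpreted (a hedged sketch
equivalent to parse_pattern in the script below; counts are words, not
tokens):

```python
def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
    # even positions mean "get n words", odd positions mean "skip n words"
    return [(i % 2 == 0, int(p)) for i, p in enumerate(pattern.split(","))]

# the default "10,1000,10,4000,10" probes 10 words, skips 1000, probes 10, ...
assert parse_pattern("10,1000,10") == [(True, 10), (False, 1000), (True, 10)]
```
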
---
 scripts/compare-logprobs.py | 281 ++++++++++++++++++++++++++++++++++++
 1 file changed, 281 insertions(+)
 create mode 100644 scripts/compare-logprobs.py

diff --git a/scripts/compare-logprobs.py b/scripts/compare-logprobs.py
new file mode 100644
index 0000000000..63861dd9a4
--- /dev/null
+++ b/scripts/compare-logprobs.py
@@ -0,0 +1,281 @@
+import argparse
+import requests
+import json
+from pathlib import Path
+import logging
+
+logger = logging.getLogger("compare-logprobs")
+logging.basicConfig(level=logging.INFO)
+
+
+DESCRIPTION = """
+Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints.
+
+Unlike compare-logits.py, it allows dumping logits from a hosted API endpoint. Useful when it's not possible to run both models locally.
+
+Example usage:
+    Step 1: Dump logits from two different servers
+        python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
+        python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions
+
+        (optionally, you can add --api-key if the endpoint requires authentication)
+
+    Step 2: Compare the dumped logits
+        python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
+"""
+
+
+def generate_input_prompt(length: int) -> list[str]:
+    CORPUS = """
+    You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
+
+    ### Tool Call Format:
+    When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
+
+    You can make multiple calls in one go by placing them one after another.
+    """
+    words = [w.strip() for w in CORPUS.strip().split(" ")]
+    words = [w for w in words if len(w) > 0]  # filter out empty strings
+    while len(words) < length:
+        words += words
+    return words[:length]
+
+
+def dump_logits(
+    endpoint: str,
+    output_path: Path,
+    input_words: list[str],
+    pattern: list[tuple[bool, int]],
+    api_key=None,
+):
+    logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
+    words = input_words
+    curr_text = ""
+    n_total = sum(n for get, n in pattern if get)
+    n_done = 0
+    i_cur = 0
+    i_total = len(words)
+    with output_path.open("w") as f:
+        for get, n in pattern:
+            if not get:
+                # skip n words
+                for i in range(n):
+                    curr_text += words.pop(0) + " "
+                    i_cur += 1
+                continue
+            # get n words
+            for i in range(n):
+                curr_text += words.pop(0) + " "
+                payload = {
+                    "prompt": curr_text.strip(),
+                    "temperature": 0.0,
+                    "top_k": 1,
+                    "max_tokens": 1,
+                    "logprobs": 1,
+                    "stream": False,
+                }
+                response = requests.post(
+                    endpoint,
+                    json=payload,
+                    headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
+                )
+                response.raise_for_status()
+                data = response.json()
+                data["__index"] = i_cur  # add index for easier debugging later
+                data = json.dumps(data)
+                f.write(f"{data}\n")
+                n_done += 1
+                i_cur += 1
+                logger.info(
+                    f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
+                )
+    logger.info(f"Logits dumped to {output_path}")
+
+
+def get_token_logprobs(data: dict):
+    logprobs = data["choices"][0]["logprobs"]
+    if "content" in logprobs:
+        # llama.cpp case
+        top = logprobs["content"][0]["top_logprobs"][0]
+        return top["token"], top["logprob"]
+    else:
+        # vllm case
+        tokens = logprobs["tokens"]
+        token_logprobs = logprobs["token_logprobs"]
+        return tokens[0], token_logprobs[0]
+
+
+def clean_text(text: str) -> str:
+    return (
+        "'"
+        + text.replace("\n", "\\n")
+        .replace("\t", "\\t")
+        .replace("\r", "\\r")
+        .replace("|", "\\|")
+        + "'"
+    )
+
+
+def compare_logits(input1: Path, input2: Path, output_path: Path):
+    with input1.open("r") as f1, input2.open("r") as f2, output_path.open("w") as fout:
+        lines1 = f1.readlines()
+        lines2 = f2.readlines()
+
+        tab_header = [
+            "idx",
+            input1.name,
+            "logprob_1",
+            input2.name,
+            "logprob_2",
+            "diff (abs)",
+        ]
+        tab_entries = []
+        tab_max_widths = [len(h) for h in tab_header]
+
+        assert len(lines1) == len(
+            lines2
+        ), "Input files must have the same number of lines."
+
+        fout.write("# Logits Comparison Report\n\n")
+        for i, (line1, line2) in enumerate(zip(lines1, lines2)):
+            if not line1.strip() or not line2.strip():
+                continue  # skip empty lines
+
+            data1 = json.loads(line1)
+            data2 = json.loads(line2)
+
+            idx1 = data1.get("__index", -1)
+            idx2 = data2.get("__index", -1)
+            if idx1 != idx2:
+                logger.warning(
+                    f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}"
+                )
+
+            token1, logprob1 = get_token_logprobs(data1)
+            token2, logprob2 = get_token_logprobs(data2)
+
+            token1 = clean_text(token1)
+            token2 = clean_text(token2)
+            abs_diff = abs(logprob1 - logprob2)
+
+            tab_entries.append(
+                (
+                    str(idx1 + 1),
+                    token1,
+                    f"{logprob1:.4f}",
+                    token2,
+                    f"{logprob2:.4f}",
+                    f"{(abs_diff):.4f}",
+                )
+            )
+
+        for i in range(len(tab_entries)):
+            for j in range(len(tab_header)):
+                tab_max_widths[j] = max(tab_max_widths[j], len(tab_entries[i][j]))
+
+        output = ""
+        for j in range(len(tab_header)):
+            output += f"| {tab_header[j]:<{tab_max_widths[j]}} "
+        output += "|\n"
+        for j in range(len(tab_header)):
+            output += f"|{'-' * (tab_max_widths[j] + 2)}"
+        output += "|\n"
+        for entry in tab_entries:
+            for j in range(len(tab_header)):
+                output += f"| {entry[j]:<{tab_max_widths[j]}} "
+            output += "|\n"
+
+        logger.info("\n" + output)
+        fout.write(output)
+    logger.info(f"Report written to {output_path}")
+
+
+def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
+    parts = pattern.split(",")
+    result = []
+    for i, part in enumerate(parts):
+        n = int(part)
+        if i % 2 == 0:
+            result.append((True, n))  # get n words
+        else:
+            result.append((False, n))  # skip n words
+    return result
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter
+    )
+    subparsers = parser.add_subparsers(
+        dest="verb", required=True, help="action to perform"
+    )
+
+    # dump subcommand
+    parser_dump = subparsers.add_parser("dump", help="dump logits from an endpoint")
+    parser_dump.add_argument(
+        "output", type=Path, help="output path for dumped logits (.log)"
+    )
+    parser_dump.add_argument(
+        "endpoint", type=str, help="OAI-compat /completions endpoint"
+    )
+    parser_dump.add_argument(
+        "--api-key",
+        type=str,
+        default=None,
+        help="API key for authentication (if required)",
+    )
+    parser_dump.add_argument(
+        "--file",
+        type=Path,
+        default=None,
+        help="File containing prompt to use instead of the default",
+    )
+    parser_dump.add_argument(
+        "--pattern",
+        type=str,
+        default="10,1000,10,4000,10",
+        help="Pattern n_get,n_skip,... where n_get is number of words to get and n_skip is number of words to skip (num of words, NOT num of tokens)",
+    )
+
+    # compare subcommand
+    parser_compare = subparsers.add_parser(
+        "compare", help="compare two dumped logits files"
+    )
+    parser_compare.add_argument("input1", type=Path, help="first input file (.log)")
+    parser_compare.add_argument("input2", type=Path, help="second input file (.log)")
+    parser_compare.add_argument(
+        "output", type=Path, help="output path for comparison report (.md)"
+    )
+
+    try:
+        return parser.parse_args()
+    except Exception as e:
+        parser.print_help()
+        raise e
+
+
+def main():
+    args = parse_args()
+
+    if args.verb == "dump":
+        pattern = parse_pattern(args.pattern)
+        input_length = sum(n for _, n in pattern)
+        input_words = generate_input_prompt(input_length)
+        if args.file is not None:
+            with args.file.open("r") as f:
+                input_words = f.read().strip().split(" ")
+            if input_length < sum(n for _, n in pattern):
+                raise ValueError(
+                    f"Input file has only {input_length} words, but pattern requires at least {input_length} words."
+                )
+            input_length = len(input_words)
+        logger.info(f"Using {input_length} words")
+        dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
+    elif args.verb == "compare":
+        compare_logits(args.input1, args.input2, args.output)
+    else:
+        raise ValueError(f"Unknown verb: {args.verb}")
+
+
+if __name__ == "__main__":
+    main()