diff --git a/common/arg.cpp b/common/arg.cpp index be61ac8e20..0226a6e644 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -504,7 +504,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // model is required (except for server) // TODO @ngxson : maybe show a list of available models in CLI in this case - if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage) { + if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) { throw std::invalid_argument("error: --model is required\n"); } @@ -639,6 +639,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) { "llama-batched-bench", "llama-bench", "llama-cli", + "llama-completion", "llama-convert-llama2c-to-ggml", "llama-cvector-generator", "llama-embedding", @@ -723,7 +724,7 @@ static void add_rpc_devices(const std::string & servers) { } } -bool common_params_parse(int argc, char ** argv, llama_example ex, std::map & out_map) { +bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & out_map) { common_params dummy_params; common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr); @@ -732,6 +733,9 @@ bool common_params_parse(int argc, char ** argv, llama_example ex, std::map & out_map); +bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & out_map); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/examples/model-conversion/scripts/causal/compare-logits.py b/examples/model-conversion/scripts/causal/compare-logits.py index 2744789099..894302c69e 100755 --- a/examples/model-conversion/scripts/causal/compare-logits.py +++ b/examples/model-conversion/scripts/causal/compare-logits.py @@ -1,10 +1,13 @@ #!/usr/bin/env python3 -import numpy as np import sys -import os +import numpy as np from pathlib import Path +# Add utils directory to path for direct script execution +sys.path.insert(0, str(Path(__file__).parent.parent / "utils")) +from common import get_model_name_from_env_path # type: ignore[import-not-found] + def quick_logits_check(pytorch_file, llamacpp_file): """Lightweight sanity check before NMSE""" @@ -35,20 +38,13 @@ def quick_logits_check(pytorch_file, llamacpp_file): return True def main(): - model_path = os.getenv('MODEL_PATH') - if not model_path: - print("Error: MODEL_PATH environment variable not set") - sys.exit(1) - - if not os.path.exists(model_path): - print(f"Error: Model file not found: {model_path}") - sys.exit(1) - - model_name = os.path.basename(model_path) + model_name = get_model_name_from_env_path('MODEL_PATH') data_dir = Path("data") - pytorch_file = data_dir / f"pytorch-{model_name}.bin" - llamacpp_file = data_dir / f"llamacpp-{model_name}.bin" + + llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL') + print(f"Using converted model: {llamacpp_model_name}") + llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin" if not pytorch_file.exists(): print(f"Error: PyTorch logits file not found: {pytorch_file}") diff --git a/examples/model-conversion/scripts/utils/__init__.py b/examples/model-conversion/scripts/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/model-conversion/scripts/utils/check-nmse.py b/examples/model-conversion/scripts/utils/check-nmse.py index 939e3153cc..83f63f9ff3 100755 --- a/examples/model-conversion/scripts/utils/check-nmse.py +++ b/examples/model-conversion/scripts/utils/check-nmse.py @@ -5,6 +5,7 @@ import sys import os import argparse from pathlib import Path +from common import get_model_name_from_env_path # type: ignore[import-not-found] def calculate_nmse(reference, test): mse = np.mean((test - reference) ** 2) @@ -67,11 +68,13 @@ def main(): parser.add_argument('-m', '--model-path', required=True, help='Path to the model directory') args = parser.parse_args() - model_name = os.path.basename(args.model_path) + model_name = get_model_name_from_env_path('MODEL_PATH') data_dir = Path("data") pytorch_file = data_dir / f"pytorch-{model_name}.bin" - llamacpp_file = data_dir / f"llamacpp-{model_name}.bin" + + llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL') + llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin" print(f"Model name: {model_name}") print(f"PyTorch logits file: {pytorch_file}") diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py new file mode 100644 index 0000000000..945f9a1a1d --- /dev/null +++ b/examples/model-conversion/scripts/utils/common.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 + +import os +import sys + +def get_model_name_from_env_path(env_path_name): + model_path = os.getenv(env_path_name) + if not model_path: + print(f"Error: {env_path_name} environment variable not set") + sys.exit(1) + + if not os.path.exists(model_path): + print(f"Error: Model file not found: {model_path}") + sys.exit(1) + + name = os.path.basename(os.path.normpath(model_path)) + if name.endswith(".gguf"): + name = name[:-5] + + return name diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a8e53f28eb..0d11d0f803 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -255,6 +255,8 @@ int main(int argc, char ** argv) { LOG_INF("target:\n\n"); common_perf_print(ctx_tgt, smpl); + llama_batch_free(batch_tgt); + common_sampler_free(smpl); common_speculative_free(spec); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c6f5809ccd..34ec09d403 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -659,6 +659,7 @@ struct vk_device_struct { vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_log[2]; vk_pipeline pipeline_tri[2]; + vk_pipeline pipeline_diag[2]; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; @@ -722,6 +723,11 @@ struct vk_device_struct { vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_soft_max_back_f32; + + vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16; + vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16; + vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; @@ -757,7 +763,8 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_split_k_reduce; - vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT]; + // [2] is for whether to take n_experts from spec constant (0) or push constant (1) + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2]; std::vector all_pipelines; @@ -1149,6 +1156,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); struct vk_op_topk_moe_push_constants { uint32_t n_rows; + uint32_t n_experts_push; uint32_t n_expert_used; float clamp_min; float clamp_max; @@ -3730,6 +3738,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -3917,6 +3926,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); @@ -3996,6 +4008,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -4204,10 +4223,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); - for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + for (uint32_t use_push = 0; use_push < 2; ++use_push) { + for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + } } for (auto &c : compiles) { @@ -8274,6 +8295,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const switch (op) { case GGML_OP_GET_ROWS: GGML_ASSERT(src1->type == GGML_TYPE_I32); + if (src0->type == GGML_TYPE_I32) { + // i32 src only supports i32 result + GGML_ASSERT(dst->type == GGML_TYPE_I32); + return ctx->device->pipeline_get_rows[src0->type]; + } if (dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_get_rows[src0->type]; } @@ -8400,6 +8426,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16]; } return nullptr; + case GGML_OP_DIAG: + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16]; + } + return nullptr; case GGML_OP_CLAMP: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_clamp_f32; @@ -8554,7 +8586,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); GGML_ASSERT(idx < num_topk_moe_pipelines); topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); - return ctx->device->pipeline_topk_moe[idx][mode]; + // use n_experts from push constant if it's not equal to the power of two spec constant + bool use_push = dst->ne[0] != (1u << idx); + return ctx->device->pipeline_topk_moe[idx][mode][use_push]; } if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { @@ -9091,6 +9125,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_COS: case GGML_OP_LOG: case GGML_OP_TRI: + case GGML_OP_DIAG: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -9778,6 +9813,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p)); } +static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p)); +} + static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); p.param1 = ggml_get_op_params_f32(dst, 0); @@ -10111,7 +10152,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, { + vk_op_soft_max_push_constants pc { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], @@ -10122,7 +10163,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, n_head_log2, nrows_x, src2 != nullptr - }); + }; + + if (ncols <= 16384) { + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc)); + } else { + + vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a; + vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a; + vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst); + + uint32_t elems_per_wg = 128 * 4; + uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg); + size_t tmp_size = num_wgs * nrows_x * sizeof(float); + + if (ctx->prealloc_size_x < tmp_size) { + ctx->prealloc_size_x = tmp_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_size_y < tmp_size) { + ctx->prealloc_size_y = tmp_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size }; + vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size }; + + std::array elements = { num_wgs, nrows_x, 1 }; + + vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32; + vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32; + vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + + ctx->prealloc_x_need_sync = true; + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -10158,6 +10247,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, vk_op_topk_moe_push_constants pc {}; pc.n_rows = n_rows; + pc.n_experts_push = n_experts; pc.n_expert_used = n_expert_used; if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; @@ -11857,6 +11947,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_TRI: ggml_vk_tri(ctx, compute_ctx, src0, node); + break; + case GGML_OP_DIAG: + ggml_vk_diag(ctx, compute_ctx, src0, node); + break; case GGML_OP_CLAMP: ggml_vk_clamp(ctx, compute_ctx, src0, node); @@ -12832,8 +12926,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc } const int n_expert = softmax->ne[0]; - // n_expert must be a power of 2 - if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) { + if (n_expert > (1 << (num_topk_moe_pipelines-1))) { return false; } @@ -13877,6 +13970,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_I32: return true; default: return false; @@ -14001,6 +14095,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LOG: case GGML_OP_TRI: + case GGML_OP_DIAG: return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && op->type == op->src[0]->type; case GGML_OP_ARGSORT: @@ -14591,6 +14686,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_log(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_TRI) { tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0)); + } else if (tensor->op == GGML_OP_DIAG) { + tensor_clone = ggml_diag(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp new file mode 100644 index 0000000000..cd3f42f491 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "rte.glsl" +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + + if (i10 == i11) { + const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val); + } else { + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index 76d83041ce..e88bdd057e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -26,9 +26,9 @@ void main() { const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; #if defined(DATA_A_BF16) - FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); + TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00])); #else - FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); + TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]); #endif #ifndef OPTIMIZATION_ERROR_WORKAROUND data_d[d_offset + i00] = D_TYPE(v); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp new file mode 100644 index 0000000000..39c4663912 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp @@ -0,0 +1,62 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = get_slope(rowx); + + // Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy_start + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + if (tid == 0) { + max_val = vals[0]; + data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp new file mode 100644 index 0000000000..69524f5f75 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp @@ -0,0 +1,79 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = get_slope(rowx); + + // Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) { + if (i + tid < gl_NumWorkGroups.x) { + max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]); + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(max_val, vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); + sum += val; + data_d[i] = D_TYPE(val); + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + if (tid == 0) { + sum = vals[0]; + data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp new file mode 100644 index 0000000000..06efd7d9fb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) { + if (i + tid < gl_NumWorkGroups.x) { + max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]); + sum += data_s[rowx * gl_NumWorkGroups.x + i + tid]; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + sumsh[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(max_val, vals[tid + s]); + sumsh[tid] += sumsh[tid + s]; + } + barrier(); + } + + max_val = vals[0]; + sum = sumsh[0]; + + if (p.has_sinks != 0) { + sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val)); + } + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl new file mode 100644 index 0000000000..6636d1f8de --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl @@ -0,0 +1,53 @@ +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; + uint has_sinks; +} p; + +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 128; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(constant_id = 1) const uint num_iters = 4; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {float data_c[];}; +layout (binding = 3) buffer D {D_TYPE data_d[];}; +layout (binding = 4) buffer M {float data_m[];}; +layout (binding = 5) buffer S {float data_s[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +float get_slope(uint rowx) { + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = (rowx / p.ne01) % p.ne02; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + return slope; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index 5cd0785d20..b83a2b9d2d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -10,6 +10,7 @@ layout (push_constant) uniform parameter { uint n_rows; + uint n_experts_push; uint n_expert_used; float clamp_min; float clamp_max; @@ -18,11 +19,16 @@ layout (push_constant) uniform parameter layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(constant_id = 0) const uint WARP_SIZE = 32; -layout(constant_id = 1) const uint n_experts = 512; +layout(constant_id = 1) const uint n_experts_spec = 512; layout(constant_id = 2) const bool with_norm = true; layout(constant_id = 3) const bool late_softmax = false; +layout(constant_id = 4) const bool nexperts_use_push = false; -const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; +uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE); layout (binding = 0, std430) readonly buffer Logits {float logits[];}; layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; @@ -94,7 +100,7 @@ void main() { } if (!late_softmax) { - softmax_warp_inplace(wt, n_experts, lane, false); + softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push); } // at this point, each thread holds a portion of softmax, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 92bae088b2..b0ade078c7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -704,13 +704,15 @@ void process_shaders() { shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp"; if (tname == "f16") { - string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); } else { - string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); } - string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); } + string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); @@ -854,6 +856,8 @@ void process_shaders() { string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("diag_f16", "diag.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("diag_f32", "diag.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -899,6 +903,13 @@ void process_shaders() { string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); diff --git a/pyrightconfig.json b/pyrightconfig.json index 5320fe5864..a7bc007bdc 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,5 +1,5 @@ { - "extraPaths": ["gguf-py"], + "extraPaths": ["gguf-py", "examples/model-conversion/scripts"], "pythonVersion": "3.9", "pythonPlatform": "All", "reportUnusedImport": "warning", diff --git a/scripts/compare-logprobs.py b/scripts/compare-logprobs.py new file mode 100644 index 0000000000..63861dd9a4 --- /dev/null +++ b/scripts/compare-logprobs.py @@ -0,0 +1,281 @@ +import argparse +import requests +import json +from pathlib import Path +import logging + +logger = logging.getLogger("compare-logprobs") +logging.basicConfig(level=logging.INFO) + + +DESCRIPTION = """ +Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints. + +Unlike compare-logits.py, it allows dumping logits from a hosted API endpoint. Useful when it's not possible to run both models locally. + +Example usage: + Step 1: Dump logits from two different servers + python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions + python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions + + (optionally, you can add --api-key if the endpoint requires authentication) + + Step 2: Compare the dumped logits + python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md +""" + + +def generate_input_prompt(length: int) -> list[str]: + CORPUS = """ + You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls. + + ### Tool Call Format: + When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text. + + You can make multiple calls in one go by placing them one after another. + """ + words = [w.strip() for w in CORPUS.strip().split(" ")] + words = [w for w in words if len(w) > 0] # filter out empty strings + while len(words) < length: + words += words + return words[:length] + + +def dump_logits( + endpoint: str, + output_path: Path, + input_words: list[str], + pattern: list[tuple[bool, int]], + api_key=None, +): + logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...") + words = input_words + curr_text = "" + n_total = sum(n for get, n in pattern if get) + n_done = 0 + i_cur = 0 + i_total = len(words) + with output_path.open("w") as f: + for get, n in pattern: + if not get: + # skip n words + for i in range(n): + curr_text += words.pop(0) + " " + i_cur += 1 + continue + # get n words + for i in range(n): + curr_text += words.pop(0) + " " + payload = { + "prompt": curr_text.strip(), + "temperature": 0.0, + "top_k": 1, + "max_tokens": 1, + "logprobs": 1, + "stream": False, + } + response = requests.post( + endpoint, + json=payload, + headers={"Authorization": f"Bearer {api_key}"} if api_key else {}, + ) + response.raise_for_status() + data = response.json() + data["__index"] = i_cur # add index for easier debugging later + data = json.dumps(data) + f.write(f"{data}\n") + n_done += 1 + i_cur += 1 + logger.info( + f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]" + ) + logger.info(f"Logits dumped to {output_path}") + + +def get_token_logprobs(data: dict): + logprobs = data["choices"][0]["logprobs"] + if "content" in logprobs: + # llama.cpp case + top = logprobs["content"][0]["top_logprobs"][0] + return top["token"], top["logprob"] + else: + # vllm case + tokens = logprobs["tokens"] + token_logprobs = logprobs["token_logprobs"] + return tokens[0], token_logprobs[0] + + +def clean_text(text: str) -> str: + return ( + "'" + + text.replace("\n", "\\n") + .replace("\t", "\\t") + .replace("\r", "\\r") + .replace("|", "\\|") + + "'" + ) + + +def compare_logits(input1: Path, input2: Path, output_path: Path): + with input1.open("r") as f1, input2.open("r") as f2, output_path.open("w") as fout: + lines1 = f1.readlines() + lines2 = f2.readlines() + + tab_header = [ + "idx", + input1.name, + "logprob_1", + input2.name, + "logprob_2", + "diff (abs)", + ] + tab_entries = [] + tab_max_widths = [len(h) for h in tab_header] + + assert len(lines1) == len( + lines2 + ), "Input files must have the same number of lines." + + fout.write("# Logits Comparison Report\n\n") + for i, (line1, line2) in enumerate(zip(lines1, lines2)): + if not line1.strip() or not line2.strip(): + continue # skip empty lines + + data1 = json.loads(line1) + data2 = json.loads(line2) + + idx1 = data1.get("__index", -1) + idx2 = data2.get("__index", -1) + if idx1 != idx2: + logger.warning( + f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}" + ) + + token1, logprob1 = get_token_logprobs(data1) + token2, logprob2 = get_token_logprobs(data2) + + token1 = clean_text(token1) + token2 = clean_text(token2) + abs_diff = abs(logprob1 - logprob2) + + tab_entries.append( + ( + str(idx1 + 1), + token1, + f"{logprob1:.4f}", + token2, + f"{logprob2:.4f}", + f"{(abs_diff):.4f}", + ) + ) + + for i in range(len(tab_entries)): + for j in range(len(tab_header)): + tab_max_widths[j] = max(tab_max_widths[j], len(tab_entries[i][j])) + + output = "" + for j in range(len(tab_header)): + output += f"| {tab_header[j]:<{tab_max_widths[j]}} " + output += "|\n" + for j in range(len(tab_header)): + output += f"|{'-' * (tab_max_widths[j] + 2)}" + output += "|\n" + for entry in tab_entries: + for j in range(len(tab_header)): + output += f"| {entry[j]:<{tab_max_widths[j]}} " + output += "|\n" + + logger.info("\n" + output) + fout.write(output) + logger.info(f"Report written to {output_path}") + + +def parse_pattern(pattern: str) -> list[tuple[bool, int]]: + parts = pattern.split(",") + result = [] + for i, part in enumerate(parts): + n = int(part) + if i % 2 == 0: + result.append((True, n)) # get n words + else: + result.append((False, n)) # skip n words + return result + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + subparsers = parser.add_subparsers( + dest="verb", required=True, help="action to perform" + ) + + # dump subcommand + parser_dump = subparsers.add_parser("dump", help="dump logits from an endpoint") + parser_dump.add_argument( + "output", type=Path, help="output path for dumped logits (.log)" + ) + parser_dump.add_argument( + "endpoint", type=str, help="OAI-compat /completions endpoint" + ) + parser_dump.add_argument( + "--api-key", + type=str, + default=None, + help="API key for authentication (if required)", + ) + parser_dump.add_argument( + "--file", + type=Path, + default=None, + help="File containing prompt to use instead of the default", + ) + parser_dump.add_argument( + "--pattern", + type=str, + default="10,1000,10,4000,10", + help="Pattern n_get,n_skip,... where n_get is number of words to get and n_skip is number of words to skip (num of words, NOT num of tokens)", + ) + + # compare subcommand + parser_compare = subparsers.add_parser( + "compare", help="compare two dumped logits files" + ) + parser_compare.add_argument("input1", type=Path, help="first input file (.log)") + parser_compare.add_argument("input2", type=Path, help="second input file (.log)") + parser_compare.add_argument( + "output", type=Path, help="output path for comparison report (.md)" + ) + + try: + return parser.parse_args() + except Exception as e: + parser.print_help() + raise e + + +def main(): + args = parse_args() + + if args.verb == "dump": + pattern = parse_pattern(args.pattern) + input_length = sum(n for _, n in pattern) + input_words = generate_input_prompt(input_length) + if args.file is not None: + with args.file.open("r") as f: + input_words = f.read().strip().split(" ") + if input_length < sum(n for _, n in pattern): + raise ValueError( + f"Input file has only {input_length} words, but pattern requires at least {input_length} words." + ) + input_length = len(input_words) + logger.info(f"Using {input_length} words") + dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key) + elif args.verb == "compare": + compare_logits(args.input1, args.input2, args.output) + else: + raise ValueError(f"Unknown verb: {args.verb}") + + +if __name__ == "__main__": + main() diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 2692297dca..9914b3276b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1318,6 +1318,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif + synchronize(); buf_output = nullptr; logits = nullptr; embd = nullptr; diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 90750b20c2..468d325e22 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -72,6 +72,10 @@ int main(void) { argv = {"binary_name", "--draft", "123"}; assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING)); + // negated arg + argv = {"binary_name", "--no-mmap"}; + assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); + printf("test-arg-parser: test valid usage\n\n"); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 308e752b1d..416218b5b8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7652,6 +7652,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f)); + for (float max_bias : {0.0f, 8.0f}) { for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) { @@ -7971,8 +7974,12 @@ static std::vector> make_test_cases_eval() { for (bool with_norm : {false, true}) { test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm)); + test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm)); test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm)); + test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm)); + test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm)); test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm)); + test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm)); } test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true)); diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 6c618a673c..3690c0bb82 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #ifdef _WIN32 #include @@ -171,7 +172,7 @@ server_presets::server_presets(int argc, char ** argv, common_params & base_para } // read base args from router's argv - common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args); + common_params_to_map(argc, argv, LLAMA_EXAMPLE_SERVER, base_args); // remove any router-controlled args from base_args for (const auto & cargs : control_args) { diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index e90e8e2d1b..8f0d15d1fd 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -11,8 +11,9 @@ endif() target_link_libraries (${TARGET} PRIVATE Threads::Threads) if (WIN32 AND NOT MSVC) - target_link_libraries(${TARGET} PUBLIC ws2_32) + target_link_libraries(${TARGET} PRIVATE ws2_32) endif() + target_compile_features(${TARGET} PRIVATE cxx_std_17) target_compile_definitions(${TARGET} PRIVATE