Merge branch 'ggml-org:master' into power-law-sampler
This commit is contained in:
commit
67a733670e
|
|
@ -504,7 +504,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
|||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage) {
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
|
||||
|
|
@ -639,6 +639,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
|
|||
"llama-batched-bench",
|
||||
"llama-bench",
|
||||
"llama-cli",
|
||||
"llama-completion",
|
||||
"llama-convert-llama2c-to-ggml",
|
||||
"llama-cvector-generator",
|
||||
"llama-embedding",
|
||||
|
|
@ -723,7 +724,7 @@ static void add_rpc_devices(const std::string & servers) {
|
|||
}
|
||||
}
|
||||
|
||||
bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
|
||||
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
|
||||
common_params dummy_params;
|
||||
common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr);
|
||||
|
||||
|
|
@ -732,6 +733,9 @@ bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<comm
|
|||
for (const auto & arg : opt.args) {
|
||||
arg_to_options[arg] = &opt;
|
||||
}
|
||||
for (const auto & arg : opt.args_neg) {
|
||||
arg_to_options[arg] = &opt;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO @ngxson : find a way to deduplicate this code
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
|
|||
|
||||
// parse input arguments from CLI into a map
|
||||
// TODO: support repeated args in the future
|
||||
bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
|
||||
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# Add utils directory to path for direct script execution
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
|
||||
from common import get_model_name_from_env_path # type: ignore[import-not-found]
|
||||
|
||||
def quick_logits_check(pytorch_file, llamacpp_file):
|
||||
"""Lightweight sanity check before NMSE"""
|
||||
|
||||
|
|
@ -35,20 +38,13 @@ def quick_logits_check(pytorch_file, llamacpp_file):
|
|||
return True
|
||||
|
||||
def main():
|
||||
model_path = os.getenv('MODEL_PATH')
|
||||
if not model_path:
|
||||
print("Error: MODEL_PATH environment variable not set")
|
||||
sys.exit(1)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error: Model file not found: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
model_name = os.path.basename(model_path)
|
||||
model_name = get_model_name_from_env_path('MODEL_PATH')
|
||||
data_dir = Path("data")
|
||||
|
||||
pytorch_file = data_dir / f"pytorch-{model_name}.bin"
|
||||
llamacpp_file = data_dir / f"llamacpp-{model_name}.bin"
|
||||
|
||||
llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
|
||||
print(f"Using converted model: {llamacpp_model_name}")
|
||||
llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"
|
||||
|
||||
if not pytorch_file.exists():
|
||||
print(f"Error: PyTorch logits file not found: {pytorch_file}")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import sys
|
|||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from common import get_model_name_from_env_path # type: ignore[import-not-found]
|
||||
|
||||
def calculate_nmse(reference, test):
|
||||
mse = np.mean((test - reference) ** 2)
|
||||
|
|
@ -67,11 +68,13 @@ def main():
|
|||
parser.add_argument('-m', '--model-path', required=True, help='Path to the model directory')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_name = os.path.basename(args.model_path)
|
||||
model_name = get_model_name_from_env_path('MODEL_PATH')
|
||||
data_dir = Path("data")
|
||||
|
||||
pytorch_file = data_dir / f"pytorch-{model_name}.bin"
|
||||
llamacpp_file = data_dir / f"llamacpp-{model_name}.bin"
|
||||
|
||||
llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
|
||||
llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"
|
||||
|
||||
print(f"Model name: {model_name}")
|
||||
print(f"PyTorch logits file: {pytorch_file}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
def get_model_name_from_env_path(env_path_name):
|
||||
model_path = os.getenv(env_path_name)
|
||||
if not model_path:
|
||||
print(f"Error: {env_path_name} environment variable not set")
|
||||
sys.exit(1)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error: Model file not found: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
name = os.path.basename(os.path.normpath(model_path))
|
||||
if name.endswith(".gguf"):
|
||||
name = name[:-5]
|
||||
|
||||
return name
|
||||
|
|
@ -255,6 +255,8 @@ int main(int argc, char ** argv) {
|
|||
LOG_INF("target:\n\n");
|
||||
common_perf_print(ctx_tgt, smpl);
|
||||
|
||||
llama_batch_free(batch_tgt);
|
||||
|
||||
common_sampler_free(smpl);
|
||||
common_speculative_free(spec);
|
||||
|
||||
|
|
|
|||
|
|
@ -659,6 +659,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_log[2];
|
||||
vk_pipeline pipeline_tri[2];
|
||||
vk_pipeline pipeline_diag[2];
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_roll_f32;
|
||||
|
|
@ -722,6 +723,11 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
||||
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
|
||||
vk_pipeline pipeline_soft_max_back_f32;
|
||||
|
||||
vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
|
||||
vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
|
||||
vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
|
||||
|
||||
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
|
||||
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
|
||||
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
|
||||
|
|
@ -757,7 +763,8 @@ struct vk_device_struct {
|
|||
|
||||
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
||||
|
||||
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
|
||||
// [2] is for whether to take n_experts from spec constant (0) or push constant (1)
|
||||
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
|
||||
|
||||
std::vector<vk_pipeline_ref> all_pipelines;
|
||||
|
||||
|
|
@ -1149,6 +1156,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
|
|||
|
||||
struct vk_op_topk_moe_push_constants {
|
||||
uint32_t n_rows;
|
||||
uint32_t n_experts_push;
|
||||
uint32_t n_expert_used;
|
||||
float clamp_min;
|
||||
float clamp_max;
|
||||
|
|
@ -3730,6 +3738,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
|
|
@ -3917,6 +3926,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
|
@ -3996,6 +4008,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
|
@ -4204,10 +4223,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
|
||||
for (uint32_t use_push = 0; use_push < 2; ++use_push) {
|
||||
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &c : compiles) {
|
||||
|
|
@ -8274,6 +8295,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
switch (op) {
|
||||
case GGML_OP_GET_ROWS:
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
if (src0->type == GGML_TYPE_I32) {
|
||||
// i32 src only supports i32 result
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
||||
return ctx->device->pipeline_get_rows[src0->type];
|
||||
}
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_get_rows[src0->type];
|
||||
}
|
||||
|
|
@ -8400,6 +8426,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_DIAG:
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CLAMP:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_clamp_f32;
|
||||
|
|
@ -8554,7 +8586,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
||||
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||
return ctx->device->pipeline_topk_moe[idx][mode];
|
||||
// use n_experts from push constant if it's not equal to the power of two spec constant
|
||||
bool use_push = dst->ne[0] != (1u << idx);
|
||||
return ctx->device->pipeline_topk_moe[idx][mode][use_push];
|
||||
}
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
||||
|
|
@ -9091,6 +9125,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
|
|
@ -9778,6 +9813,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = ggml_get_op_params_f32(dst, 0);
|
||||
|
|
@ -10111,7 +10152,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
|
||||
vk_op_soft_max_push_constants pc {
|
||||
ncols,
|
||||
src1 != nullptr ? nrows_y : (uint32_t)0,
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
|
||||
|
|
@ -10122,7 +10163,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
n_head_log2,
|
||||
nrows_x,
|
||||
src2 != nullptr
|
||||
});
|
||||
};
|
||||
|
||||
if (ncols <= 16384) {
|
||||
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
|
||||
} else {
|
||||
|
||||
vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
|
||||
vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
|
||||
vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
|
||||
vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
|
||||
|
||||
uint32_t elems_per_wg = 128 * 4;
|
||||
uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
|
||||
size_t tmp_size = num_wgs * nrows_x * sizeof(float);
|
||||
|
||||
if (ctx->prealloc_size_x < tmp_size) {
|
||||
ctx->prealloc_size_x = tmp_size;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
if (ctx->prealloc_size_y < tmp_size) {
|
||||
ctx->prealloc_size_y = tmp_size;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
|
||||
vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
|
||||
vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
|
||||
|
||||
std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
|
||||
|
||||
vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
|
||||
vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
|
||||
vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
|
||||
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
|
||||
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
ctx->prealloc_y_need_sync = true;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
|
@ -10158,6 +10247,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
|
||||
vk_op_topk_moe_push_constants pc {};
|
||||
pc.n_rows = n_rows;
|
||||
pc.n_experts_push = n_experts;
|
||||
pc.n_expert_used = n_expert_used;
|
||||
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
|
||||
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
|
||||
|
|
@ -11857,6 +11947,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_OP_TRI:
|
||||
ggml_vk_tri(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
ggml_vk_diag(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
ggml_vk_clamp(ctx, compute_ctx, src0, node);
|
||||
|
|
@ -12832,8 +12926,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
|
|||
}
|
||||
|
||||
const int n_expert = softmax->ne[0];
|
||||
// n_expert must be a power of 2
|
||||
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
|
||||
if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -13877,6 +13970,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_I32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -14001,6 +14095,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->type == op->src[0]->type;
|
||||
case GGML_OP_ARGSORT:
|
||||
|
|
@ -14591,6 +14686,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
|
||||
} else if (tensor->op == GGML_OP_TRI) {
|
||||
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
|
||||
} else if (tensor->op == GGML_OP_DIAG) {
|
||||
tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
|
||||
} else if (tensor->op == GGML_OP_CLAMP) {
|
||||
const float * params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
|
||||
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
|
||||
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
|
||||
const uint i12_offset = i12*p.ne11*p.ne10;
|
||||
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
|
||||
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
|
||||
|
||||
if (i10 == i11) {
|
||||
const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
|
||||
} else {
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
|
@ -26,9 +26,9 @@ void main() {
|
|||
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||
|
||||
#if defined(DATA_A_BF16)
|
||||
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
|
||||
TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
|
||||
#else
|
||||
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
|
||||
TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]);
|
||||
#endif
|
||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||
data_d[d_offset + i00] = D_TYPE(v);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
#version 450
|
||||
|
||||
#include "soft_max_large_common.glsl"
|
||||
|
||||
void main() {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint rowx = gl_WorkGroupID.y;
|
||||
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
|
||||
|
||||
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
|
||||
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
|
||||
const uint32_t i01 = rowx % p.ne01;
|
||||
|
||||
uint rowy_start = 0;
|
||||
if (p.KY > 0) {
|
||||
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
|
||||
}
|
||||
|
||||
if (rowx >= p.nrows_x) {
|
||||
return;
|
||||
}
|
||||
|
||||
float slope = get_slope(rowx);
|
||||
|
||||
// Find max
|
||||
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
|
||||
|
||||
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
|
||||
const uint col = col0 + tid;
|
||||
|
||||
FLOAT_TYPE a = FLOAT_TYPE(0);
|
||||
if (col < p.KX) {
|
||||
a = data_a[rowx * p.KX + col];
|
||||
}
|
||||
|
||||
FLOAT_TYPE b = FLOAT_TYPE(0);
|
||||
if (p.KY > 0 && col < p.KX) {
|
||||
b = data_b[rowy_start + col];
|
||||
}
|
||||
|
||||
FLOAT_TYPE v = a * p.scale + slope * b;
|
||||
|
||||
if (col < p.KX) {
|
||||
max_val = max(max_val, v);
|
||||
}
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
vals[tid] = max_val;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
vals[tid] = max(vals[tid], vals[tid + s]);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
max_val = vals[0];
|
||||
data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
#version 450
|
||||
|
||||
#include "soft_max_large_common.glsl"
|
||||
|
||||
void main() {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint rowx = gl_WorkGroupID.y;
|
||||
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
|
||||
|
||||
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
|
||||
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
|
||||
const uint32_t i01 = rowx % p.ne01;
|
||||
|
||||
uint rowy_start = 0;
|
||||
if (p.KY > 0) {
|
||||
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
|
||||
}
|
||||
|
||||
if (rowx >= p.nrows_x) {
|
||||
return;
|
||||
}
|
||||
|
||||
float slope = get_slope(rowx);
|
||||
|
||||
// Find max
|
||||
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
|
||||
if (i + tid < gl_NumWorkGroups.x) {
|
||||
max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
|
||||
}
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
vals[tid] = max_val;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
vals[tid] = max(max_val, vals[tid + s]);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
max_val = vals[0];
|
||||
barrier();
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
|
||||
|
||||
// Compute sum{exp(x - max)}
|
||||
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
|
||||
const uint col = col0 + tid;
|
||||
|
||||
if (col >= p.KX) {
|
||||
break;
|
||||
}
|
||||
|
||||
// compute exp(a*scale+b*slope), add it to sum
|
||||
const uint i = rowx * p.KX + col;
|
||||
FLOAT_TYPE val;
|
||||
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
|
||||
sum += val;
|
||||
data_d[i] = D_TYPE(val);
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
vals[tid] = sum;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
vals[tid] += vals[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
sum = vals[0];
|
||||
data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
#version 450
|
||||
|
||||
#include "soft_max_large_common.glsl"
|
||||
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint rowx = gl_WorkGroupID.y;
|
||||
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
|
||||
|
||||
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
|
||||
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
|
||||
const uint32_t i01 = rowx % p.ne01;
|
||||
|
||||
uint rowy_start = 0;
|
||||
if (p.KY > 0) {
|
||||
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
|
||||
}
|
||||
|
||||
if (rowx >= p.nrows_x) {
|
||||
return;
|
||||
}
|
||||
|
||||
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
|
||||
|
||||
[[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
|
||||
if (i + tid < gl_NumWorkGroups.x) {
|
||||
max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
|
||||
sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
|
||||
}
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
vals[tid] = max_val;
|
||||
sumsh[tid] = sum;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
vals[tid] = max(max_val, vals[tid + s]);
|
||||
sumsh[tid] += sumsh[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
max_val = vals[0];
|
||||
sum = sumsh[0];
|
||||
|
||||
if (p.has_sinks != 0) {
|
||||
sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
|
||||
}
|
||||
|
||||
FLOAT_TYPE rcpdivisor = 1.0/sum;
|
||||
|
||||
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
|
||||
const uint col = col0 + tid;
|
||||
|
||||
if (col >= p.KX) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint KX;
|
||||
uint KY;
|
||||
uint ne00;
|
||||
uint ne01;
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
uint ne13;
|
||||
uint nb11;
|
||||
uint nb12;
|
||||
uint nb13;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
uint n_head_log2;
|
||||
uint nrows_x;
|
||||
uint has_sinks;
|
||||
} p;
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 128;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(constant_id = 1) const uint num_iters = 4;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
|
||||
layout (binding = 2) readonly buffer Z {float data_c[];};
|
||||
layout (binding = 3) buffer D {D_TYPE data_d[];};
|
||||
layout (binding = 4) buffer M {float data_m[];};
|
||||
layout (binding = 5) buffer S {float data_s[];};
|
||||
|
||||
shared FLOAT_TYPE vals[BLOCK_SIZE];
|
||||
|
||||
float get_slope(uint rowx) {
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (p.max_bias > 0.0f) {
|
||||
const uint h = (rowx / p.ne01) % p.ne02; // head index
|
||||
|
||||
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
|
||||
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
return slope;
|
||||
}
|
||||
|
|
@ -10,6 +10,7 @@
|
|||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint n_rows;
|
||||
uint n_experts_push;
|
||||
uint n_expert_used;
|
||||
float clamp_min;
|
||||
float clamp_max;
|
||||
|
|
@ -18,11 +19,16 @@ layout (push_constant) uniform parameter
|
|||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
|
||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||
layout(constant_id = 1) const uint n_experts = 512;
|
||||
layout(constant_id = 1) const uint n_experts_spec = 512;
|
||||
layout(constant_id = 2) const bool with_norm = true;
|
||||
layout(constant_id = 3) const bool late_softmax = false;
|
||||
layout(constant_id = 4) const bool nexperts_use_push = false;
|
||||
|
||||
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
||||
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
|
||||
|
||||
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
|
||||
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
|
||||
|
|
@ -94,7 +100,7 @@ void main() {
|
|||
}
|
||||
|
||||
if (!late_softmax) {
|
||||
softmax_warp_inplace(wt, n_experts, lane, false);
|
||||
softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
|
||||
}
|
||||
|
||||
// at this point, each thread holds a portion of softmax,
|
||||
|
|
|
|||
|
|
@ -704,13 +704,15 @@ void process_shaders() {
|
|||
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
|
||||
|
||||
if (tname == "f16") {
|
||||
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
|
||||
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
|
||||
} else {
|
||||
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
|
||||
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
|
||||
}
|
||||
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
|
||||
}
|
||||
|
||||
string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}});
|
||||
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
|
||||
|
|
@ -854,6 +856,8 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("diag_f16", "diag.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("diag_f32", "diag.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
|
@ -899,6 +903,13 @@ void process_shaders() {
|
|||
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
{
|
||||
"extraPaths": ["gguf-py"],
|
||||
"extraPaths": ["gguf-py", "examples/model-conversion/scripts"],
|
||||
"pythonVersion": "3.9",
|
||||
"pythonPlatform": "All",
|
||||
"reportUnusedImport": "warning",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,281 @@
|
|||
import argparse
|
||||
import requests
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("compare-logprobs")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
DESCRIPTION = """
|
||||
Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints.
|
||||
|
||||
Unlike compare-logits.py, it allows dumping logits from a hosted API endpoint. Useful when it's not possible to run both models locally.
|
||||
|
||||
Example usage:
|
||||
Step 1: Dump logits from two different servers
|
||||
python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
|
||||
python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions
|
||||
|
||||
(optionally, you can add --api-key <key> if the endpoint requires authentication)
|
||||
|
||||
Step 2: Compare the dumped logits
|
||||
python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
|
||||
"""
|
||||
|
||||
|
||||
def generate_input_prompt(length: int) -> list[str]:
|
||||
CORPUS = """
|
||||
You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
|
||||
|
||||
### Tool Call Format:
|
||||
When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
|
||||
|
||||
You can make multiple calls in one go by placing them one after another.
|
||||
"""
|
||||
words = [w.strip() for w in CORPUS.strip().split(" ")]
|
||||
words = [w for w in words if len(w) > 0] # filter out empty strings
|
||||
while len(words) < length:
|
||||
words += words
|
||||
return words[:length]
|
||||
|
||||
|
||||
def dump_logits(
|
||||
endpoint: str,
|
||||
output_path: Path,
|
||||
input_words: list[str],
|
||||
pattern: list[tuple[bool, int]],
|
||||
api_key=None,
|
||||
):
|
||||
logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
|
||||
words = input_words
|
||||
curr_text = ""
|
||||
n_total = sum(n for get, n in pattern if get)
|
||||
n_done = 0
|
||||
i_cur = 0
|
||||
i_total = len(words)
|
||||
with output_path.open("w") as f:
|
||||
for get, n in pattern:
|
||||
if not get:
|
||||
# skip n words
|
||||
for i in range(n):
|
||||
curr_text += words.pop(0) + " "
|
||||
i_cur += 1
|
||||
continue
|
||||
# get n words
|
||||
for i in range(n):
|
||||
curr_text += words.pop(0) + " "
|
||||
payload = {
|
||||
"prompt": curr_text.strip(),
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"max_tokens": 1,
|
||||
"logprobs": 1,
|
||||
"stream": False,
|
||||
}
|
||||
response = requests.post(
|
||||
endpoint,
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
data["__index"] = i_cur # add index for easier debugging later
|
||||
data = json.dumps(data)
|
||||
f.write(f"{data}\n")
|
||||
n_done += 1
|
||||
i_cur += 1
|
||||
logger.info(
|
||||
f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
|
||||
)
|
||||
logger.info(f"Logits dumped to {output_path}")
|
||||
|
||||
|
||||
def get_token_logprobs(data: dict):
|
||||
logprobs = data["choices"][0]["logprobs"]
|
||||
if "content" in logprobs:
|
||||
# llama.cpp case
|
||||
top = logprobs["content"][0]["top_logprobs"][0]
|
||||
return top["token"], top["logprob"]
|
||||
else:
|
||||
# vllm case
|
||||
tokens = logprobs["tokens"]
|
||||
token_logprobs = logprobs["token_logprobs"]
|
||||
return tokens[0], token_logprobs[0]
|
||||
|
||||
|
||||
def clean_text(text: str) -> str:
|
||||
return (
|
||||
"'"
|
||||
+ text.replace("\n", "\\n")
|
||||
.replace("\t", "\\t")
|
||||
.replace("\r", "\\r")
|
||||
.replace("|", "\\|")
|
||||
+ "'"
|
||||
)
|
||||
|
||||
|
||||
def compare_logits(input1: Path, input2: Path, output_path: Path):
|
||||
with input1.open("r") as f1, input2.open("r") as f2, output_path.open("w") as fout:
|
||||
lines1 = f1.readlines()
|
||||
lines2 = f2.readlines()
|
||||
|
||||
tab_header = [
|
||||
"idx",
|
||||
input1.name,
|
||||
"logprob_1",
|
||||
input2.name,
|
||||
"logprob_2",
|
||||
"diff (abs)",
|
||||
]
|
||||
tab_entries = []
|
||||
tab_max_widths = [len(h) for h in tab_header]
|
||||
|
||||
assert len(lines1) == len(
|
||||
lines2
|
||||
), "Input files must have the same number of lines."
|
||||
|
||||
fout.write("# Logits Comparison Report\n\n")
|
||||
for i, (line1, line2) in enumerate(zip(lines1, lines2)):
|
||||
if not line1.strip() or not line2.strip():
|
||||
continue # skip empty lines
|
||||
|
||||
data1 = json.loads(line1)
|
||||
data2 = json.loads(line2)
|
||||
|
||||
idx1 = data1.get("__index", -1)
|
||||
idx2 = data2.get("__index", -1)
|
||||
if idx1 != idx2:
|
||||
logger.warning(
|
||||
f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}"
|
||||
)
|
||||
|
||||
token1, logprob1 = get_token_logprobs(data1)
|
||||
token2, logprob2 = get_token_logprobs(data2)
|
||||
|
||||
token1 = clean_text(token1)
|
||||
token2 = clean_text(token2)
|
||||
abs_diff = abs(logprob1 - logprob2)
|
||||
|
||||
tab_entries.append(
|
||||
(
|
||||
str(idx1 + 1),
|
||||
token1,
|
||||
f"{logprob1:.4f}",
|
||||
token2,
|
||||
f"{logprob2:.4f}",
|
||||
f"{(abs_diff):.4f}",
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(len(tab_entries)):
|
||||
for j in range(len(tab_header)):
|
||||
tab_max_widths[j] = max(tab_max_widths[j], len(tab_entries[i][j]))
|
||||
|
||||
output = ""
|
||||
for j in range(len(tab_header)):
|
||||
output += f"| {tab_header[j]:<{tab_max_widths[j]}} "
|
||||
output += "|\n"
|
||||
for j in range(len(tab_header)):
|
||||
output += f"|{'-' * (tab_max_widths[j] + 2)}"
|
||||
output += "|\n"
|
||||
for entry in tab_entries:
|
||||
for j in range(len(tab_header)):
|
||||
output += f"| {entry[j]:<{tab_max_widths[j]}} "
|
||||
output += "|\n"
|
||||
|
||||
logger.info("\n" + output)
|
||||
fout.write(output)
|
||||
logger.info(f"Report written to {output_path}")
|
||||
|
||||
|
||||
def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
|
||||
parts = pattern.split(",")
|
||||
result = []
|
||||
for i, part in enumerate(parts):
|
||||
n = int(part)
|
||||
if i % 2 == 0:
|
||||
result.append((True, n)) # get n words
|
||||
else:
|
||||
result.append((False, n)) # skip n words
|
||||
return result
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter
|
||||
)
|
||||
subparsers = parser.add_subparsers(
|
||||
dest="verb", required=True, help="action to perform"
|
||||
)
|
||||
|
||||
# dump subcommand
|
||||
parser_dump = subparsers.add_parser("dump", help="dump logits from an endpoint")
|
||||
parser_dump.add_argument(
|
||||
"output", type=Path, help="output path for dumped logits (.log)"
|
||||
)
|
||||
parser_dump.add_argument(
|
||||
"endpoint", type=str, help="OAI-compat /completions endpoint"
|
||||
)
|
||||
parser_dump.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for authentication (if required)",
|
||||
)
|
||||
parser_dump.add_argument(
|
||||
"--file",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="File containing prompt to use instead of the default",
|
||||
)
|
||||
parser_dump.add_argument(
|
||||
"--pattern",
|
||||
type=str,
|
||||
default="10,1000,10,4000,10",
|
||||
help="Pattern n_get,n_skip,... where n_get is number of words to get and n_skip is number of words to skip (num of words, NOT num of tokens)",
|
||||
)
|
||||
|
||||
# compare subcommand
|
||||
parser_compare = subparsers.add_parser(
|
||||
"compare", help="compare two dumped logits files"
|
||||
)
|
||||
parser_compare.add_argument("input1", type=Path, help="first input file (.log)")
|
||||
parser_compare.add_argument("input2", type=Path, help="second input file (.log)")
|
||||
parser_compare.add_argument(
|
||||
"output", type=Path, help="output path for comparison report (.md)"
|
||||
)
|
||||
|
||||
try:
|
||||
return parser.parse_args()
|
||||
except Exception as e:
|
||||
parser.print_help()
|
||||
raise e
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.verb == "dump":
|
||||
pattern = parse_pattern(args.pattern)
|
||||
input_length = sum(n for _, n in pattern)
|
||||
input_words = generate_input_prompt(input_length)
|
||||
if args.file is not None:
|
||||
with args.file.open("r") as f:
|
||||
input_words = f.read().strip().split(" ")
|
||||
if input_length < sum(n for _, n in pattern):
|
||||
raise ValueError(
|
||||
f"Input file has only {input_length} words, but pattern requires at least {input_length} words."
|
||||
)
|
||||
input_length = len(input_words)
|
||||
logger.info(f"Using {input_length} words")
|
||||
dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
|
||||
elif args.verb == "compare":
|
||||
compare_logits(args.input1, args.input2, args.output)
|
||||
else:
|
||||
raise ValueError(f"Unknown verb: {args.verb}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1318,6 +1318,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
||||
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
#endif
|
||||
synchronize();
|
||||
buf_output = nullptr;
|
||||
logits = nullptr;
|
||||
embd = nullptr;
|
||||
|
|
|
|||
|
|
@ -72,6 +72,10 @@ int main(void) {
|
|||
argv = {"binary_name", "--draft", "123"};
|
||||
assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING));
|
||||
|
||||
// negated arg
|
||||
argv = {"binary_name", "--no-mmap"};
|
||||
assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
|
||||
|
||||
|
||||
printf("test-arg-parser: test valid usage\n\n");
|
||||
|
||||
|
|
|
|||
|
|
@ -7652,6 +7652,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
|
||||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
|
||||
|
||||
for (float max_bias : {0.0f, 8.0f}) {
|
||||
for (float scale : {1.0f, 0.1f}) {
|
||||
for (int64_t ne0 : {16, 1024}) {
|
||||
|
|
@ -7971,8 +7974,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
|
||||
for (bool with_norm : {false, true}) {
|
||||
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <queue>
|
||||
#include <filesystem>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
|
|
@ -171,7 +172,7 @@ server_presets::server_presets(int argc, char ** argv, common_params & base_para
|
|||
}
|
||||
|
||||
// read base args from router's argv
|
||||
common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
|
||||
common_params_to_map(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
|
||||
|
||||
// remove any router-controlled args from base_args
|
||||
for (const auto & cargs : control_args) {
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ endif()
|
|||
target_link_libraries (${TARGET} PRIVATE Threads::Threads)
|
||||
|
||||
if (WIN32 AND NOT MSVC)
|
||||
target_link_libraries(${TARGET} PUBLIC ws2_32)
|
||||
target_link_libraries(${TARGET} PRIVATE ws2_32)
|
||||
endif()
|
||||
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
target_compile_definitions(${TARGET} PRIVATE
|
||||
|
|
|
|||
Loading…
Reference in New Issue