diff --git a/common/arg.cpp b/common/arg.cpp index b52b3e70b7..c3610d262b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -854,6 +854,54 @@ bool common_arg_utils::is_autoy(const std::string & value) { return value == "auto" || value == "-1"; } +// Simple CSV parser that handles quoted fields and escaped quotes +// example: +// input: value1,"value, with, commas","value with ""escaped"" quotes",value4 +// output: [value1] [value, with, commas] [value with "escaped" quotes] [value4] +static std::vector parse_csv_row(const std::string& input) { + std::vector fields; + std::string field; + bool in_quotes = false; + + for (size_t i = 0; i < input.length(); ++i) { + char ch = input[i]; + + if (ch == '"') { + if (!in_quotes) { + // start of quoted field (only valid if at beginning of field) + if (!field.empty()) { + // quote appeared in middle of unquoted field, treat as literal + field += '"'; + } else { + in_quotes = true; // start + } + } else { + if (i + 1 < input.length() && input[i + 1] == '"') { + // escaped quote: "" + field += '"'; + ++i; // skip the next quote + } else { + in_quotes = false; // end + } + } + } else if (ch == ',') { + if (in_quotes) { + field += ','; + } else { + fields.push_back(std::move(field)); + field.clear(); + } + } else { + field += ch; + } + } + + // Add the last field + fields.push_back(std::move(field)); + + return fields; +} + common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { // per-example default params // we define here to make sure it's included in llama-gen-docs @@ -1250,7 +1298,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--in-file"}, "FNAME", "an input file (use comma-separated values to specify multiple files)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { std::ifstream file(item); if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str())); @@ -2002,7 +2050,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--image", "--audio"}, "FILE", "path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.image.emplace_back(item); } } @@ -2259,37 +2307,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex )); add_opt(common_arg( {"--override-kv"}, "KEY=TYPE:VALUE,...", - "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n" + "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated values.\n" "types: int, float, bool, str. 
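As a reading aid for this hunk, here is a self-contained sketch of the behaviour the new parse_csv_row helper provides. The parser body mirrors the one added above (element type std::vector<std::string>); the main() driver and its assertions are illustrative only and are not part of the patch.

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

static std::vector<std::string> parse_csv_row(const std::string & input) {
    std::vector<std::string> fields;
    std::string field;
    bool in_quotes = false;

    for (size_t i = 0; i < input.length(); ++i) {
        const char ch = input[i];
        if (ch == '"') {
            if (!in_quotes) {
                if (!field.empty()) {
                    field += '"';               // quote in the middle of an unquoted field: keep it literally
                } else {
                    in_quotes = true;           // start of a quoted field
                }
            } else if (i + 1 < input.length() && input[i + 1] == '"') {
                field += '"';                   // escaped quote: ""
                ++i;                            // skip the second quote
            } else {
                in_quotes = false;              // end of the quoted field
            }
        } else if (ch == ',' && !in_quotes) {
            fields.push_back(std::move(field)); // field separator outside quotes
            field.clear();
        } else {
            field += ch;
        }
    }
    fields.push_back(std::move(field));         // last field
    return fields;
}

int main() {
    const auto fields = parse_csv_row(
        "value1,\"value, with, commas\",\"value with \"\"escaped\"\" quotes\",value4");

    assert(fields.size() == 4);
    assert(fields[1] == "value, with, commas");
    assert(fields[2] == "value with \"escaped\" quotes");

    for (const auto & f : fields) {
        std::cout << '[' << f << "] ";          // [value1] [value, with, commas] ...
    }
    std::cout << '\n';
}

This keeps the old behaviour for plain comma-separated values while letting users pass items that themselves contain commas by quoting them.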
example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false", [](common_params & params, const std::string & value) { - std::vector kv_overrides; - - std::string current; - bool escaping = false; - - for (const char c : value) { - if (escaping) { - current.push_back(c); - escaping = false; - } else if (c == '\\') { - escaping = true; - } else if (c == ',') { - kv_overrides.push_back(current); - current.clear(); - } else { - current.push_back(c); - } - } - - if (escaping) { - current.push_back('\\'); - } - - kv_overrides.push_back(current); - - for (const auto & kv_override : kv_overrides) { - if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) { - throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str())); + for (const auto & item : parse_csv_row(value)) { + if (!string_parse_kv_override(item.c_str(), params.kv_overrides)) { + throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", item.c_str())); } } } @@ -2306,7 +2329,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--lora"}, "FNAME", "path to LoRA adapter (use comma-separated values to load multiple adapters)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.lora_adapters.push_back({ item, 1.0, "", "", nullptr }); } } @@ -2317,7 +2340,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n" "note: use comma-separated values", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { auto parts = string_split(item, ':'); if (parts.size() != 2) { throw std::invalid_argument("lora-scaled format: FNAME:SCALE"); @@ -2331,7 +2354,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--control-vector"}, "FNAME", "add a control vector\nnote: use comma-separated values to add multiple control vectors", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.control_vectors.push_back({ 1.0f, item, }); } } @@ -2341,7 +2364,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "add a control vector with user defined scaling SCALE\n" "note: use comma-separated values (format: FNAME:SCALE,...)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { auto parts = string_split(item, ':'); if (parts.size() != 2) { throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE"); @@ -2439,7 +2462,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--context-file"}, "FNAME", "file to load context from (use comma-separated values to specify multiple files)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { std::ifstream file(item, std::ios::binary); if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str())); @@ -2675,9 +2698,13 @@ common_params_context 
common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(common_arg( {"--api-key"}, "KEY", - "API key to use for authentication (default: none)", + "API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)", [](common_params & params, const std::string & value) { - params.api_keys.push_back(value); + for (const auto & key : parse_csv_row(value)) { + if (!key.empty()) { + params.api_keys.push_back(key); + } + } } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); add_opt(common_arg( @@ -2691,7 +2718,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::string key; while (std::getline(key_file, key)) { if (!key.empty()) { - params.api_keys.push_back(key); + params.api_keys.push_back(key); } } key_file.close(); @@ -2713,7 +2740,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE")); add_opt(common_arg( {"--chat-template-kwargs"}, "STRING", - string_format("sets additional params for the json template parser"), + "sets additional params for the json template parser, must be a valid json object string, e.g. '{\"key1\":\"value1\",\"key2\":\"value2\"}'", [](common_params & params, const std::string & value) { auto parsed = json::parse(value); for (const auto & item : parsed.items()) { diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 50b6bd00e4..6b718e01c3 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1963,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * acl_tensor_ptr acl_weight_tensor; // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); if (weight_to_nz && is_matmul_weight(weight)) { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index e9a21e1b05..6895349b20 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -103,7 +103,7 @@ const ggml_cann_device_info & ggml_cann_info(); void ggml_cann_set_device(int32_t device); int32_t ggml_cann_get_device(); -std::optional get_env(const std::string & name); +std::optional get_env_as_lowercase(const std::string & name); bool parse_bool(const std::string & value); int parse_integer(const std::string & value); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index acf88db9b7..162d238ae4 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -105,10 +105,10 @@ int32_t ggml_cann_get_device() { } /** - * @brief Get the value of the specified environment variable (name). + * @brief Get the value of the specified environment variable (name) as lowercase. * if not empty, return a std::string object */ -std::optional get_env(const std::string & name) { +std::optional get_env_as_lowercase(const std::string & name) { const char * val = std::getenv(name.c_str()); if (!val) { return std::nullopt; @@ -259,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. 
*/ explicit ggml_cann_pool_buf_prio(int device) : device(device) { - disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); + disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); } /** @@ -452,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_buf(int device) : device(device) { - disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); + disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); } /** @@ -764,7 +764,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * @return A unique pointer to the created CANN pool. */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device(int device) { - std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); + std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or(""); if (mem_pool_type == "prio") { GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); @@ -1217,7 +1217,7 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, // Why aclrtSynchronizeDevice? // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); if (!need_transform(tensor->type)) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { @@ -1442,7 +1442,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t int64_t ne0 = tensor->ne[0]; // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); // last line must bigger than 32, because every single op deal at // least 32 bytes. @@ -2136,7 +2136,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #endif // USE_ACL_GRAPH // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. - static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or("")); + static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or("")); if (!use_cann_graph || cann_graph_capture_required) { for (int i = 0; i < cgraph->n_nodes; i++) { @@ -2201,7 +2201,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, #ifdef USE_ACL_GRAPH bool use_cann_graph = true; - static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); + static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); if (!prefill_use_graph) { // Do not use acl_graph for prefill. 
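The helper renamed above exists so that CANN options read from the environment can be matched case-insensitively: it returns the variable's value folded to lower case, which is what parse_bool then compares against. A minimal sketch of that contract follows; the lower-casing body is assumed to be the pre-existing one and is shown only to make the rename self-explanatory.

#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <optional>
#include <string>

// Sketch of the renamed helper: return the value of the environment variable
// `name` converted to lower case, or std::nullopt if it is not set.
static std::optional<std::string> get_env_as_lowercase(const std::string & name) {
    const char * val = std::getenv(name.c_str());
    if (!val) {
        return std::nullopt;
    }
    std::string res(val);
    std::transform(res.begin(), res.end(), res.begin(),
                   [](unsigned char c) { return (char) std::tolower(c); });
    return res;
}

// Typical call site, as used throughout this patch:
//   static bool weight_to_nz =
//       parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));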
for (int i = 0; i < cgraph->n_nodes; i++) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 995b774c20..9516d8ec8f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1036,7 +1036,7 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_graph_node_properties { +struct ggml_cuda_graph_node_properties { void * node_address; ggml_op node_op; int64_t ne[GGML_MAX_DIMS]; @@ -1061,11 +1061,25 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; - bool cuda_graphs_enabled = false; - std::vector ggml_graph_properties; - std::vector extraneous_srcs_properties; + std::vector props; + + void record_update(bool use_graph, bool update_required) { + if (use_graph && update_required) { + number_consecutive_updates++; + } else { + number_consecutive_updates = 0; + } + if (number_consecutive_updates >= 4) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); + disable_due_to_too_many_updates = true; + } + } + + bool is_enabled() const { + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + } #endif }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75269170c3..bac69cdd1c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2853,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility(ggml_cgraph * cgraph, - bool use_cuda_graph) { +static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { + bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; @@ -2915,41 +2915,41 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, return use_cuda_graph; } -static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - graph_node_properties->node_address = node->data; - graph_node_properties->node_op = node->op; +static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { + props->node_address = node->data; + props->node_op = node->op; for (int i = 0; i < GGML_MAX_DIMS; i++) { - graph_node_properties->ne[i] = node->ne[i]; - graph_node_properties->nb[i] = node->nb[i]; + props->ne[i] = node->ne[i]; + props->nb[i] = node->nb[i]; } for (int i = 0; i < GGML_MAX_SRC; i++) { - graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + props->src_address[i] = node->src[i] ? 
node->src[i]->data : nullptr; } - memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); + memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); } -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - if (node->data != graph_node_properties->node_address && +static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { + if (node->data != props->node_address && node->op != GGML_OP_VIEW) { return false; } - if (node->op != graph_node_properties->node_op) { + if (node->op != props->node_op) { return false; } for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != graph_node_properties->ne[i]) { + if (node->ne[i] != props->ne[i]) { return false; } - if (node->nb[i] != graph_node_properties->nb[i]) { + if (node->nb[i] != props->nb[i]) { return false; } } for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && + node->src[i]->data != props->src_address[i] && node->op != GGML_OP_VIEW ) { return false; @@ -2957,56 +2957,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && - memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { return false; } return true; } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { - bool cuda_graph_update_required = false; + bool res = false; if (cuda_ctx->cuda_graph->instance == nullptr) { - cuda_graph_update_required = true; + res = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { - cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes + cgraph->n_leafs); + if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + res = true; + cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); } // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = true; - - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + bool props_match = true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]); } - if (!has_matching_properties) { - cuda_graph_update_required = true; + if (!props_match) { + res = true; } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]); } for (int i = 0; i < cgraph->n_leafs; i++) { - bool has_matching_properties = true; - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]); + bool props_match= true; + if (!res) { + props_match = 
ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]); } - if (!has_matching_properties) { - cuda_graph_update_required = true; + if (!props_match) { + res = true; } - set_ggml_graph_node_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]); + ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); } - return cuda_graph_update_required; + return res; } -static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) { #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; @@ -3237,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) { + bool graph_evaluated_or_captured = false; + // flag used to determine whether it is an integrated_gpu - const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); bool is_concurrent_event_active = false; @@ -3710,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - update_cuda_graph_executable(cuda_ctx); + ggml_cuda_graph_update_executable(cuda_ctx); } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); @@ -3720,43 +3720,25 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } -static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) { +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { #ifdef USE_CUDA_GRAPH - static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - // Objects required for CUDA Graph if (cuda_ctx->cuda_graph == nullptr) { cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); } - bool use_cuda_graph = true; - if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; -#ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); -#endif } } - // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, - // or previous graph capture failure. - // Also disable for multi-gpu for now. 
TO DO investigate - if (disable_cuda_graphs_due_to_env - || cuda_ctx->cuda_graph->disable_due_to_gpu_arch - || cuda_ctx->cuda_graph->disable_due_to_too_many_updates - || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { - use_cuda_graph = false; - } - - cuda_ctx->cuda_graph->cuda_graphs_enabled = use_cuda_graph; + return cuda_ctx->cuda_graph->is_enabled(); #else - bool use_cuda_graph = false; + return false; #endif // USE_CUDA_GRAPH - - return use_cuda_graph; } static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { @@ -3767,30 +3749,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool use_cuda_graph = false; bool cuda_graph_update_required = false; - // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called) - // we call it here instead. #ifdef USE_CUDA_GRAPH - use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx); + use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); - if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); + if (cuda_ctx->cuda_graph->is_enabled()) { + cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); + use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); - use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); - - // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. - if (use_cuda_graph && cuda_graph_update_required) { - cuda_ctx->cuda_graph->number_consecutive_updates++; - } else { - cuda_ctx->cuda_graph->number_consecutive_updates = 0; - } - - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { - cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; - cuda_ctx->cuda_graph->cuda_graphs_enabled = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); -#endif - } + cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required); } #endif // USE_CUDA_GRAPH @@ -3804,9 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } - bool graph_evaluated_or_captured = false; - - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required); return GGML_STATUS_SUCCESS; } @@ -3839,7 +3803,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; - const bool use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx); + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); static bool enable_graph_optimization = [] { const char * env = getenv("GGML_CUDA_GRAPH_OPT"); diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 691d8dcb14..60542fc19d 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -34,13 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // CUDA_GRAPHS_DISABLED ((ncols > 65536) && ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - 
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture)) || + ctx.cuda_graph->is_enabled())) || // CUDA_GRAPHS ENABLED ((ncols > 32768) && !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture))) { + ctx.cuda_graph->is_enabled()))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 85692d4543..ceb95758d2 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -333,6 +333,28 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } if (amd_wmma_available(cc)) { + // RDNA 4 is consistently worse on rocblas + // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301 + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + // High expert counts almost always better on MMQ + // due to a large amount of graph splits + // https://github.com/ggml-org/llama.cpp/pull/18202 + if (n_experts >= 64) { + return true; + } + + switch (type) { + // These quants are really bad on MMQ + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q6_K: + // These quants are usually worse but not always + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + return ne11 <= 128; + default: + return true; + } + } return true; } diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 6b424381df..c1d4e2bc8d 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1) #endif // __clang__ // assumes as many threads as d_state -template +template __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, @@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1) const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { - const int head_idx = (blockIdx.x * splitH) / d_head; - const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); - const int seq_idx = blockIdx.y; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int warp_idx = blockIdx.x * c_factor + warp; + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); - const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); - const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); - const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); - const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); - const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); - float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH; - float * s_block = (float *) 
((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); // strides across n_seq_tokens const int stride_x = src1_nb2 / sizeof(float); @@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1) const int stride_C = src5_nb2 / sizeof(float); const int stride_y = n_head * d_head; - float state[splitH]; - // for the parallel accumulation - __shared__ float stateC[splitH * d_state]; + float state[c_factor]; + float state_sum = 0.0f; #pragma unroll - for (int j = 0; j < splitH; j++) { - state[j] = s0_block[j * d_state + threadIdx.x]; + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; } for (int64_t i = 0; i < n_tok; i++) { - // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements - // TODO: only calculate B and C once per head group - // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here. - float dt_soft_plus = dt_block[i * stride_dt]; - if (dt_soft_plus <= 20.0f) { - dt_soft_plus = log1pf(expf(dt_soft_plus)); - } - const float dA = expf(dt_soft_plus * A_block[0]); - const float B = B_block[i * stride_B + threadIdx.x]; - const float C = C_block[i * stride_C + threadIdx.x]; + // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here. + // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead. + const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]); - // across d_head + state_sum = 0.0f; + const float dA = expf(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; #pragma unroll - for (int j = 0; j < splitH; j++) { - const float x_dt = x_block[i * stride_x + j] * dt_soft_plus; - - state[j] = (state[j] * dA) + (B * x_dt); - - stateC[j * d_state + threadIdx.x] = state[j] * C; + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; } - __syncthreads(); + // parallel accumulation for output + state_sum = warp_reduce_sum(state_sum); - // parallel accumulation for stateC - // TODO: simplify - { - static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2"); - static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2"); - - // reduce until w matches the warp size - // TODO: does this work even when the physical warp size is 64? 
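For review, the recurrence that ssm_scan_f32_group now evaluates with one warp per (head, head-dim) element, keeping d_state values in registers, is equivalent to the scalar reference below. Strides, group broadcasting and the sequence/head indexing are deliberately left out, so this is a sketch of the math rather than drop-in code; the inner j loop is what the kernel spreads across warp lanes and folds with warp_reduce_sum.

#include <cmath>
#include <vector>

static void ssm_scan_reference_row(
        std::vector<float> & state,   // [d_state] persistent SSM state for this row
        const float *        x,       // [n_tok]   input value per token for this row
        const float *        dt,      // [n_tok]   time-step parameter per token
        float                A,       // scalar decay parameter for this head
        const float *        B,       // [n_tok * d_state]
        const float *        C,       // [n_tok * d_state]
        float *              y,       // [n_tok]   output value per token
        int                  n_tok,
        int                  d_state) {
    for (int i = 0; i < n_tok; ++i) {
        // softplus, with the same large-value shortcut as the kernel
        const float dt_sp = dt[i] <= 20.0f ? std::log1p(std::exp(dt[i])) : dt[i];
        const float dA    = std::exp(dt_sp * A);
        const float x_dt  = x[i] * dt_sp;

        float sum = 0.0f;
        for (int j = 0; j < d_state; ++j) {           // one lane per j in the kernel
            state[j] = state[j] * dA + B[i * d_state + j] * x_dt;
            sum     += state[j] * C[i * d_state + j];
        }
        y[i] = sum;                                   // warp_reduce_sum + lane-0 store
    }
}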
-#pragma unroll - for (int w = d_state; w > WARP_SIZE; w >>= 1) { - // (assuming there are d_state threads) -#pragma unroll - for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) { - // TODO: check for bank conflicts - const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1)); - stateC[k] += stateC[k + (w >> 1)]; - - } - __syncthreads(); - } - - static_assert(splitH >= d_state / WARP_SIZE); - -#pragma unroll - for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) { - float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)]; - y = warp_reduce_sum(y); - - // store the above accumulations - if (threadIdx.x % WARP_SIZE == 0) { - const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE); - y_block[i * stride_y + k] = y; - } - } + if (lane == 0) { + y_warp[i * stride_y] = state_sum; } } // write back the state #pragma unroll - for (int j = 0; j < splitH; j++) { - s_block[j * d_state + threadIdx.x] = state[j]; + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; } } @@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { - const int threads = 128; // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! if (src3_nb1 == sizeof(float)) { // Mamba-2 if (d_state == 128) { - GGML_ASSERT(d_state % threads == 0); - // NOTE: can be any power of two between 4 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 128><<>>( + constexpr int threads = 128; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); } else if (d_state == 256) { // Falcon-H1 - const int threads = 256; - // NOTE: can be any power of two between 8 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 256><<>>( + constexpr int threads = 256; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa } } else { // Mamba-1 + constexpr int threads = 128; GGML_ASSERT(n_head % threads == 0); GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 13b96d61f8..365a24b496 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct 
ggml_tensor * x, const struct ggml_ return true; } +static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * src3 = op->src[3]; + const struct ggml_tensor * src4 = op->src[4]; + const struct ggml_tensor * dst = op; + + // Check for F16 support only as requested + if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) { + return false; + } + + if (src3 && src3->type != GGML_TYPE_F16) { // mask + return false; + } + + if (src4 && src4->type != GGML_TYPE_F32) { // sinks + return false; + } + + // For now we support F32 or F16 output as htp backend often converts output on the fly if needed, + // but the op implementation writes to F16 or F32. + // Let's assume dst can be F32 or F16. + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + return opt_experimental; +} + static bool hex_supported_src0_type(ggml_type t) { return t == GGML_TYPE_F32; } @@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + if (dst->type != GGML_TYPE_F32) { return false; } - // TODO: add support for non-cont tensors - if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) { return false; } @@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; // typically the lm-head which would be too large for VTCM } - // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false; if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { return false; } @@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session } break; - case GGML_TYPE_F16: - if (!opt_experimental) { - return false; - } - break; - default: return false; } - // TODO: add support for non-cont tensors - if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { - return false; - } - return true; } @@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s return true; } +static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * src1 = op->src[1]; // indices + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) { + return false; + } + + if (dst->type != GGML_TYPE_F16) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * src1 = op->src[1]; // indices + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) { + return false; + } + + if (dst->type != GGML_TYPE_F32) { + return false; + } + + return true; +} + static bool 
ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; @@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t d->offset = (uint8_t *) t->data - buf->base; d->size = ggml_nbytes(t); + if (!d->size) { + // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty + d->size = 64; + } + switch (type) { case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: // Flush CPU @@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu return n_bufs; } +static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_GET_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { switch (t->op) { @@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * return n_bufs; } +static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_SET_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); @@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf supported = true; break; + case GGML_OP_SCALE: + req->op = HTP_OP_SCALE; + supported = true; + break; + case GGML_OP_UNARY: if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { req->op = HTP_OP_UNARY_SILU; @@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs return n_bufs; } +static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + req->op = HTP_OP_FLASH_ATTN_EXT; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->name.c_str(); @@ -2417,6 +2522,7 @@ static ggml_status 
ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_UNARY: @@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_SET_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_GET_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); break; @@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_rope(sess, op); break; + case GGML_OP_FLASH_ATTN_EXT: + supp = ggml_hexagon_supported_flash_attn_ext(sess, op); + break; + + case GGML_OP_SET_ROWS: + supp = ggml_hexagon_supported_set_rows(sess, op); + break; + + case GGML_OP_GET_ROWS: + supp = ggml_hexagon_supported_get_rows(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 2cf8aaa42a..6a34a215fa 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED softmax-ops.c act-ops.c rope-ops.c + flash-attn-ops.c + set-rows-ops.c + get-rows-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c new file mode 100644 index 0000000000..04a7b843ce --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -0,0 +1,566 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = 
Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + + hvx_vec_store_u(r, 4, rsum); +} + +// Dot product of two F16 vectors, accumulating to float +static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(r, 4, rsum); +} + +// MAD: y (F32) += x (F16) * v (float) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { + const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S = hvx_vec_splat_fp16(s); + + uint32_t i = 0; + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); + ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + } + + if (nloe) { + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + + HVX_Vector xs = Q6_V_lo_W(xs_p); + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + } + + if (nloe) { + HVX_Vector xy = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_u(&ptr_y[i], nloe * 4, xy); + } + } +} + +#define FLASH_ATTN_BLOCK_SIZE 128 + +static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; + const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; + struct htp_tensor * dst = &octx->dst; + + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; + + const uint32_t nek0 = k->ne[0]; + const uint32_t nek1 = k->ne[1]; + const uint32_t nek2 = k->ne[2]; + const uint32_t nek3 = k->ne[3]; + + const uint32_t nev0 = v->ne[0]; + const uint32_t nev1 = v->ne[1]; + const uint32_t nev2 = v->ne[2]; + const uint32_t nev3 = v->ne[3]; + + const uint32_t nbq1 = q->nb[1]; + const uint32_t nbq2 = q->nb[2]; + const uint32_t nbq3 = q->nb[3]; + + const uint32_t nbk1 = k->nb[1]; + const uint32_t nbk2 = k->nb[2]; + const uint32_t nbk3 = k->nb[3]; + + const uint32_t nbv1 = v->nb[1]; + const uint32_t nbv2 = v->nb[2]; + const uint32_t nbv3 = v->nb[3]; + + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + // total rows in q + const uint32_t nr = neq1*neq2*neq3; + + const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, nr); + + if (ir0 >= ir1) return; + + dma_queue * dma = octx->ctx->dma[ith]; + + const uint32_t DK = nek0; + const uint32_t DV = nev0; + + const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 
4 : 2); + const size_t size_q_row_padded = htp_round_up(size_q_row, 128); + + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask + + const size_t size_k_row_padded = htp_round_up(size_k_row, 128); + const size_t size_v_row_padded = htp_round_up(size_v_row, 128); + + const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator + uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; + uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith; + uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith; + uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; + uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); + + const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + + const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + + // Fetch Q row + const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + // Clear accumulator + float * VKQ32 = (float *) spad_a; + memset(VKQ32, 0, DV * sizeof(float)); + + const __fp16 * mp_base = NULL; + if (mask) { + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); + } + + const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + // Prefetch first two blocks + for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); + uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; + // Mask is 1D contiguous for this row + dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + } + } + + const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + + for (uint32_t ib = 0; ib < n_blocks; ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // Wait for DMA + uint8_t * k_base = dma_queue_pop(dma).dst; // K + uint8_t * v_base = dma_queue_pop(dma).dst; // V + __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + + // Inner loop processing the block from VTCM + uint32_t ic = 0; + + // Process in blocks of 32 (VLEN_FP32) + for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + // 1. Compute scores + float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic + j; + const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; + if (q->type == HTP_TYPE_F32) { + hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + } + } + + HVX_Vector scores = *(HVX_Vector *) scores_arr; + + // 2. Softcap + if (logit_softcap != 0.0f) { + scores = hvx_vec_tanh_fp32(scores); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 3. 
Mask + if (mask) { + const __fp16 * mp = m_base + ic; + HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp; + + HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00); + HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16); + + HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair)); + + HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); + HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec); + scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 4. Online Softmax Update + HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores); + float m_block = hvx_vec_get_fp32(v_max); + + float M_old = M; + float M_new = (m_block > M) ? m_block : M; + M = M_new; + + float ms = expf(M_old - M_new); + + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + S = S * ms; + + HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new); + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted)); + + HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P); + float p_sum = hvx_vec_get_fp32(p_sum_vec); + S += p_sum; + + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; + + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } + } + + // Leftover + for (; ic < current_block_size; ++ic) { + float s_val; + const uint8_t * k_ptr = k_base + ic * size_k_row_padded; + + if (q->type == HTP_TYPE_F32) { + hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } + + if (logit_softcap != 0.0f) { + s_val = logit_softcap * tanhf(s_val); + } + + if (mask) { + const float m_val = m_base[ic]; + s_val += slope * m_val; + } + + const float Mold = M; + float ms = 1.0f; + float vs = 1.0f; + + if (s_val > M) { + M = s_val; + ms = expf(Mold - M); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s_val - M); + } + + const uint8_t * v_ptr = v_base + ic * size_v_row_padded; + + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); + + S = S * ms + vs; + } + + // Issue DMA for next+1 block (if exists) + if (ib + 2 < n_blocks) { + const uint32_t next_ib = ib + 2; + const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); + dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + } + } + } + + // sinks + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + 
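/*
 * For reference, the sinks branch above and the per-score leftover path in the
 * block loop apply the same online-softmax merge. A minimal scalar sketch of
 * that update (hypothetical helper; M/S/acc stand for the running maximum,
 * denominator and weighted V accumulator used by this kernel):
 *
 *     static void online_softmax_step(float s, const float * v, int DV,
 *                                     float * M, float * S, float * acc) {
 *         float ms = 1.0f;          // rescale factor for the existing state
 *         float vs = 1.0f;          // softmax weight of the new score
 *         if (s > *M) {
 *             ms = expf(*M - s);    // new maximum: downscale what was accumulated so far
 *             *M = s;
 *         } else {
 *             vs = expf(s - *M);    // below the maximum: fold in at reduced weight
 *         }
 *         for (int i = 0; i < DV; ++i) {
 *             acc[i] = acc[i] * ms + (v ? v[i] : 0.0f) * vs;
 *         }
 *         *S = *S * ms + vs;        // denominator tracks the same rescaling
 *     }
 *
 * A sink contributes only to M and S (it has no V row, i.e. v == NULL in the
 * sketch), which is why the block above rescales VKQ32 by ms but performs no mad.
 */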
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv); + + // Store result + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // dst is permuted + uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + + if (dst->type == HTP_TYPE_F32) { + hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } else if (dst->type == HTP_TYPE_F16) { + hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } + } +} + +static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + flash_attn_ext_f16_thread(octx, i, n); +} + +int op_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; + struct htp_tensor * dst = &octx->dst; + + // Check support + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || + k->type != HTP_TYPE_F16 || + v->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + octx->src0_div1 = init_fastdiv_values(q->ne[1]); + + octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + + if (mask) { + octx->src3_div2 = init_fastdiv_values(mask->ne[2]); + octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + } + + size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); + size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128); + size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128); + + size_t size_q_block = size_q_row_padded * 1; // single row for now + size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 + + octx->src0_spad.size_per_thread = size_q_block * 1; + octx->src1_spad.size_per_thread = size_k_block * 2; + octx->src2_spad.size_per_thread = size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? 
size_m_block * 2 : 0; + octx->dst_spad.size_per_thread = size_vkq_acc; + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads; + octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size; + + if (octx->ctx->vtcm_size < total_spad) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c new file mode 100644 index 0000000000..54321421eb --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -0,0 +1,112 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define get_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t nr = ne10 * ne11 * ne12; + +static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + get_rows_preamble; + + // parallelize by src1 elements (which correspond to dst rows) + const uint32_t dr = octx->src1_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t rem = i - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i01 = is_i32 ? 
*(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + // invalid index, skip for now to avoid crash + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; + const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; + hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + + return HTP_STATUS_OK; +} + +static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_get_rows(struct htp_ops_context * octx) { + get_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); + octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 5c3d217f1c..4bd0ea7a36 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -11,11 +11,6 @@ #define HTP_MAX_NTHREADS 10 -// FIXME: move these into matmul-ops -#define HTP_SPAD_SRC0_NROWS 16 -#define HTP_SPAD_SRC1_NROWS 16 -#define HTP_SPAD_DST_NROWS 2 - // Main context for htp DSP backend struct htp_context { dspqueue_t queue; diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index a61652304a..846d061784 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -36,6 +36,8 @@ enum htp_data_type { HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, HTP_TYPE_Q8_0 = 8, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, HTP_TYPE_MXFP4 = 39, HTP_TYPE_COUNT }; @@ -57,6 +59,10 @@ enum htp_op { HTP_OP_SOFTMAX = 11, HTP_OP_ADD_ID = 12, HTP_OP_ROPE = 13, + HTP_OP_FLASH_ATTN_EXT = 14, + HTP_OP_SET_ROWS = 15, + HTP_OP_SCALE = 16, + HTP_OP_GET_ROWS = 17, INVALID }; @@ -137,6 +143,8 @@ struct htp_general_req { struct htp_tensor src0; // Input0 tensor struct htp_tensor src1; // Input1 tensor struct htp_tensor src2; // Input2 tensor + struct htp_tensor src3; // Input3 tensor + struct htp_tensor src4; // Input4 tensor struct htp_tensor dst; // Output tensor // should be multiple of 64 bytes (cacheline) @@ -152,6 +160,6 @@ struct htp_general_rsp { }; #define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) -#define HTP_MAX_PACKET_BUFFERS 4 +#define HTP_MAX_PACKET_BUFFERS 8 #endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index e87657436f..7c828ae636 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -13,6 +13,7 @@ struct htp_spad { uint8_t * data; + size_t stride; size_t size; size_t size_per_thread; }; @@ -26,11 +27,14 @@ struct htp_ops_context { struct htp_tensor src0; struct htp_tensor src1; struct htp_tensor src2; + struct htp_tensor src3; + struct htp_tensor src4; struct htp_tensor dst; struct htp_spad src0_spad; struct htp_spad src1_spad; struct htp_spad src2_spad; + struct htp_spad 
src3_spad; struct htp_spad dst_spad; worker_pool_context_t * wpool; // worker pool @@ -49,6 +53,27 @@ struct htp_ops_context { struct fastdiv_values src1_div3; // fastdiv values for ne3 struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 + struct fastdiv_values src3_div1; // fastdiv values for ne1 + struct fastdiv_values src3_div2; // fastdiv values for ne2 + struct fastdiv_values src3_div3; // fastdiv values for ne3 + struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 + struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 + struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 + struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 + + struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 + struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 + + struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 + struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 + uint32_t flags; }; @@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx); int op_softmax(struct htp_ops_context * octx); int op_add_id(struct htp_ops_context * octx); int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c index f9e02ab67e..29d73b8622 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.c +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.c @@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) { return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); } -void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector scale_vec = hvx_vec_splat_fp32(scale); - - if (0 == unaligned_loop) { - HVX_Vector * vec_in1 = (HVX_Vector *) src; - HVX_Vector * vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); - hvx_vec_store_u((void 
*) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src, hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec); } } + + diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index d2d5d23636..22876e6dba 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) } #endif -static inline HVX_Vector hvx_vec_splat_fp32(float i) { +static inline HVX_Vector hvx_vec_splat_fp32(float v) { union { - float f; - int32_t i; - } fp32 = { .f = i }; + float f; + uint32_t i; + } fp32 = { .f = v }; return Q6_V_vsplat_R(fp32.i); } +static inline HVX_Vector hvx_vec_splat_fp16(float v) { + union { + __fp16 f; + uint16_t i; + } fp16 = { .f = v }; + + return Q6_Vh_vsplat_R(fp16.i); +} + static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) { // Rotate as needed. v = Q6_V_vlalign_VVR(v, v, (size_t) addr); @@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest } } +// copy n fp32 elements : source is unaligned, destination unaligned +static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; + HVX_UVector * restrict vsrc = (HVX_UVector *) src; + + assert((unsigned long) dst % 128 == 0); + + uint32_t nvec = n / 32; + uint32_t nloe = n % 32; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 + HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned +static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 + HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = 
n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned +static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16 + HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + // bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) { HVX_Vector * restrict vdst = (HVX_Vector *) dst; @@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3 return right_off <= chunk_size; } - - static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { HVX_VectorAlias u = { .v = v }; @@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) { } static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { -#if __HTP_ARCH__ > 75 +#if __HVX_ARCH__ > 75 return Q6_Vsf_vfneg_Vsf(v); #else // neg by setting the fp32 sign bit HVX_Vector mask = Q6_V_vsplat_R(0x80000000); return Q6_V_vxor_VV(v, mask); -#endif // __HTP_ARCH__ > 75 +#endif // __HVX_ARCH__ > 75 } // ==================================================== @@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); } +static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) { + // tanh(x) = 2 * sigmoid(2x) - 1 + HVX_Vector two = hvx_vec_splat_fp32(2.0f); + HVX_Vector one = hvx_vec_splat_fp32(1.0f); + HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); + + static const float kMinExp = -87.f; // 0 + static const float kMaxExp = 87.f; // 1 + HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); + + HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); + + HVX_Vector res = 
Q6_Vqf32_vmpy_VsfVsf(sig2x, two); + res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); + return Q6_Vsf_equals_Vqf32(res); +} + static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { int step_of_1 = num_elems >> 5; int remaining = num_elems - step_of_1 * VLEN_FP32; @@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr } } +static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + + HVX_Vector * vsrc = (HVX_Vector *) src; + HVX_Vector * vdst = (HVX_Vector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + + HVX_UVector * vsrc = (HVX_UVector *) src; + HVX_UVector * vdst = (HVX_UVector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { + hvx_scale_f32_aa(dst, src, n, scale); + } else { + hvx_scale_f32_uu(dst, src, n, scale); + } +} + +static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + HVX_Vector vo = hvx_vec_splat_fp32(offset); + + HVX_Vector * vsrc = (HVX_Vector *) src; + HVX_Vector * vdst = (HVX_Vector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + HVX_Vector vo = hvx_vec_splat_fp32(offset); + + HVX_UVector * vsrc = (HVX_UVector *) src; + HVX_UVector * vdst = (HVX_UVector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void 
hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { + hvx_scale_offset_f32_aa(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_uu(dst, src, n, scale, offset); + } +} float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems); void hvx_mul_f32(const uint8_t * restrict src0, @@ -1090,7 +1338,6 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0, uint8_t * restrict dst, const int num_elems); void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale); void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate); diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index fb5508a560..24b3e90e4b 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_get_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_matmul_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context * ctx, uint32_t n_bufs) { struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - int write_idx = (n_bufs == 4) ? 
3 : 2; + int write_idx = n_bufs - 1; // We had written to the output buffer, we'd also need to flush it rsp_bufs[0].fd = bufs[write_idx].fd; @@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_set_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_flash_attn_ext_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + // Setup Op context + struct htp_ops_context octx; + memset(&octx, 0, sizeof(octx)); + + octx.ctx = ctx; + octx.n_threads = ctx->n_threads; + + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.src3 = req->src3; + octx.src4 = req->src4; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + + int last_buf = 3; + + if (octx.src3.ne[0]) { + octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid + } + + if (octx.src4.ne[0]) { + octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid + } + + octx.dst.data = (uint32_t) bufs[last_buf].ptr; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_flash_attn_ext(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + + struct dspqueue_buffer rsp_buf = bufs[last_buf]; + rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); +} + static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; @@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { break; case HTP_OP_RMS_NORM: + case HTP_OP_SCALE: if (n_bufs != 2) { FARF(ERROR, "Bad unary-req buffer list"); continue; @@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_rope_req(ctx, &req, bufs, n_bufs); break; + case HTP_OP_FLASH_ATTN_EXT: + if (!(n_bufs >= 4 && n_bufs <= 6)) { + FARF(ERROR, "Bad 
flash-attn-ext-req buffer list"); + continue; + } + proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs); + break; + + case HTP_OP_SET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad set-rows-req buffer list"); + continue; + } + proc_set_rows_req(ctx, &req, bufs); + break; + + case HTP_OP_GET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad get-rows-req buffer list"); + continue; + } + proc_get_rows_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index f14523d485..9bb39db9fc 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -26,14 +26,14 @@ #include "hvx-utils.h" #include "ops-utils.h" +#define MM_SPAD_SRC0_NROWS 16 +#define MM_SPAD_SRC1_NROWS 16 +#define MM_SPAD_DST_NROWS 2 + struct htp_matmul_type { const char * type; void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy); + void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); }; typedef struct { @@ -907,145 +907,174 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); } -#if 1 -static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - if (0) { - float rsum = 0; - const __fp16 * restrict vx = (const __fp16 * restrict) x; - const float * restrict vy = (const float * restrict) y; +static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; - for (uint32_t i = 0; i < n; i++) { - rsum += (float)vx[i] * vy[i]; - } - *s = rsum; - return; - } + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; - const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y; + HVX_Vector rsum = Q6_V_vsplat_R(0); - uint32_t nv0 = n / 64; // num full fp16 hvx vectors - uint32_t nv1 = n % 64; // leftover elements - - // for some reason we need volatile here so that the compiler doesn't try anything funky - volatile HVX_Vector rsum = Q6_V_vsplat_R(0); - float r_sum_scalar = 0.0f; uint32_t i = 0; - for (i = 0; i < nv0; i++) { - HVX_VectorPair yp = vy[i]; - - HVX_Vector x = vx[i]; - HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - - //NOTE: need volatile here to prevent compiler optimization - // Seem compiler cannot guarantee read-after-write?? 
- volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - - HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - if (nv1) { - // HVX_VectorPair yp = vy[i]; + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - // HVX_Vector x = vx[i]; - // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - - // if (nv1 >= 32) { - // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); - // nv1 -= 32; - // } - - // rsum = hvx_vec_qf32_reduce_sum(rsum); - - // if (nv1) { - // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); - // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); - // } - - //process the remainder using scalar loop - rsum = hvx_vec_qf32_reduce_sum(rsum); - const __fp16 * restrict sx = (const __fp16 * restrict) x; - const float * restrict sy = (const float * restrict) y; - - for (uint32_t i = nv0 * 64; i < n; i++) { - r_sum_scalar += (float) sx[i] * sy[i]; - } - - // hvx_vec_dump_fp16("X", x); - // hvx_vec_dump_fp16("Y", y); - // hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum)); - // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum)); - } else { - rsum = hvx_vec_qf32_reduce_sum(rsum); + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar; - -# ifdef HTP_DEBUG - { - float rsum = 0; - const __fp16 * restrict vx = (const __fp16 * restrict) x; - const float * restrict vy = (const float * restrict) y; - - for (uint32_t i = 0; i < n; i++) { - rsum += vx[i] * vy[i]; - } - - float diff = fabs(*s - rsum); - if (diff > 0.001) { - FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s); - // htp_dump_f16("x", vx, n); - // htp_dump_f32("y", vy, n); - } - } -# endif + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); } -#else -static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - const uint32_t fk = 64; - const uint32_t nb = n / fk; - assert(n % fk == 0); - assert(nb % 4 == 0); +static void vec_dot_f16_f16_aa_rx2(const int n, + float * restrict s, + const void * restrict vx, + uint32_t vx_row_size, + const void * restrict vy) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; + const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); + const HVX_Vector * restrict y = (const HVX_Vector *) vy; - const uint32_t x_blk_size = 2 * fk; // fp16 - const uint32_t y_blk_size = 4 * fk; // fp32 + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; - // Row sum (qf32) HVX_Vector rsum0 = Q6_V_vsplat_R(0); HVX_Vector rsum1 = Q6_V_vsplat_R(0); - HVX_Vector rsum2 = Q6_V_vsplat_R(0); - HVX_Vector rsum3 = 
Q6_V_vsplat_R(0); - for (uint32_t i = 0; i < nb; i += 4) { - HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size)); - HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size)); + uint32_t i = 0; - HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]); - HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]); - HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]); - HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]); + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = y[i]; + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1))); - rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2))); - rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3))); + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); } - // Reduce and convert into fp32 - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1); - rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3); - HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2)); - hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum)); -} -#endif + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); -#define htp_matmul_preamble \ + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + } + + rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1)); + HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); + + hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); +} + +static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_UVector * restrict x = (const HVX_UVector *) vx; + const HVX_UVector * restrict y = (const HVX_UVector *) vy; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * 
restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +#define htp_matmul_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict src2 = &octx->src2; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ @@ -1056,6 +1085,11 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const uint32_t ne12 = src1->ne[2]; \ const uint32_t ne13 = src1->ne[3]; \ \ + const uint32_t ne20 = src2->ne[0]; \ + const uint32_t ne21 = src2->ne[1]; \ + const uint32_t ne22 = src2->ne[2]; \ + const uint32_t ne23 = src2->ne[3]; \ + \ const uint32_t ne0 = dst->ne[0]; \ const uint32_t ne1 = dst->ne[1]; \ const uint32_t ne2 = dst->ne[2]; \ @@ -1076,18 +1110,94 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -// q8x4 src1 tensor is already in VTCM spad -static void matmul(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +#define htp_matmul_preamble \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + +// *** matmul with 
support for 4d tensors and full broadcasting + +static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { + htp_matmul_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const uint32_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const uint32_t nr1 = ne1 * ne2 * ne3; + + // distribute the thread work across the inner or outer loop based on which one is larger + uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + // The number of elements in each chunk + const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + uint32_t current_chunk = ith; + + const uint32_t ith0 = current_chunk % nchunk0; + const uint32_t ith1 = current_chunk / nchunk0; + + const uint32_t ir0_start = dr0 * ith0; + const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); + + const uint32_t ir1_start = dr1 * ith1; + const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); + + // no work for this thread + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { + return; + } + + // block-tiling attempt + const uint32_t blck_0 = 64; + const uint32_t blck_1 = 64; + + for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { + const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); + + // broadcast src0 into src1 + const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); + const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + + const uint32_t i1 = i11; + const uint32_t i2 = i12; + const uint32_t i3 = i13; + + const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); + const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); + float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); + for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { + const uint8_t * restrict src0_row = src0_base + ir0 * nb01; + mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// src1 tensor is already in VTCM spad +static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1104,9 +1214,10 @@ static void matmul(struct htp_matmul_type * mt, const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = nb11; - const 
size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1124,11 +1235,11 @@ static void matmul(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); } // Process src0 rows @@ -1137,17 +1248,17 @@ static void matmul(struct htp_matmul_type * mt, #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); } } @@ -1155,13 +1266,13 @@ static void matmul(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; const int is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); } @@ -1176,17 +1287,7 @@ static void matmul(struct htp_matmul_type * mt, } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, 
uint32_t ith) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; @@ -1202,9 +1303,10 @@ static void matvec(struct htp_matmul_type * mt, const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = nb11; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1226,24 +1328,24 @@ static void matvec(struct htp_matmul_type * mt, #pragma unroll(2) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); } // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col); + mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); } } @@ -1251,8 +1353,8 @@ static void matvec(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } @@ -1274,22 +1376,13 @@ struct mmid_row_mapping { uint32_t i2; }; -// q8x4 src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict ids, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict src2_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +// src1 tensor is already in VTCM spad +static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; + struct 
htp_spad * restrict src2_spad = &octx->src2_spad; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1340,7 +1433,7 @@ static void matmul_id(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), @@ -1365,8 +1458,8 @@ static void matmul_id(struct htp_matmul_type * mt, } // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); @@ -1404,22 +1497,13 @@ static void matmul_id(struct htp_matmul_type * mt, dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// q8x4 src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict src2, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict src2_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +// src1 tensor is already in VTCM spad +static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; + struct htp_spad * restrict src2_spad = &octx->src2_spad; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1464,7 +1548,7 @@ static void matvec_id(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), @@ -1477,8 +1561,8 @@ static void matvec_id(struct htp_matmul_type * mt, mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); @@ -1504,106 +1588,6 @@ static void matvec_id(struct htp_matmul_type * mt, dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// *** matmul in fp16 - -static void matmul_f16_f32(struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * 
dma_queue) { - htp_matmul_preamble; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); - - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) - const uint32_t nr0 = ne0; - - // This is the size of the rest of the dimensions of the result - const uint32_t nr1 = ne1 * ne2 * ne3; - - // distribute the thread work across the inner or outer loop based on which one is larger - uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - // The number of elements in each chunk - const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - - uint32_t current_chunk = ith; - - const uint32_t ith0 = current_chunk % nchunk0; - const uint32_t ith1 = current_chunk / nchunk0; - - const uint32_t ir0_start = dr0 * ith0; - const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); - - const uint32_t ir1_start = dr1 * ith1; - const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); - - // broadcast factors - const uint32_t r2 = ne12 / ne02; - const uint32_t r3 = ne13 / ne03; - - // no work for this thread - if (ir0_start >= ir0_end || ir1_start >= ir1_end) { - return; - } - - // block-tiling attempt - const uint32_t blck_0 = 64; - const uint32_t blck_1 = 64; - - __attribute__((aligned(128))) float tmp[64]; - - for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { - for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { - for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = (ir1 / (ne12 * ne1)); - const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; - const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); - - // broadcast src0 into src1 - const uint32_t i03 = i13 / r3; - const uint32_t i02 = i12 / r2; - - const uint32_t i1 = i11; - const uint32_t i2 = i12; - const uint32_t i3 = i13; - - const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); - const uint8_t * restrict src1_col = - (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); - float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - - const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); - for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { - // Use nb01 stride for non-contiguous src0 support - const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col); - } - - hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0); - } - } - } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} - // *** dynamic quant static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { @@ -1780,20 +1764,14 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u for (uint32_t i = 0; i < nb; i++) { #if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * 
qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #else #error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" #endif @@ -1848,14 +1826,95 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, + uint32_t nrows_per_thread, uint32_t dst_stride) { + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_stride); + hvx_copy_fp16_fp32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// TODO just a plain copy that should be done via the DMA during the Op setup +static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, + uint32_t nrows_per_thread, uint32_t dst_stride) { + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const 
uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_stride); + hvx_copy_fp16_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); } -// ** matmul callbacks for worker_pool +static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); +} -static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); +} + +// ** matmul/matvec callbacks for worker_pool + +static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1863,11 +1922,10 @@ static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1875,11 +1933,10 @@ static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1887,11 +1944,10 @@ static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { 
+static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1899,11 +1955,10 @@ static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1911,11 +1966,10 @@ static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1923,14 +1977,49 @@ static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; - matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_aa; + mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; + + matvec_2d(&mt, octx, n, i); +} + +static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_aa; + mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; + + matmul_2d(&mt, octx, n, i); +} + +static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f32"; + mt.vec_dot = vec_dot_f16_f32_uu; + + matmul_4d(&mt, octx, n, i); +} + +static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_uu; + + matmul_4d(&mt, octx, n, i); } // ** matmul-id callbacks for worker_pool @@ -1943,8 +2032,7 @@ static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, 
i); } static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1955,8 +2043,7 @@ static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1967,8 +2054,7 @@ static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1979,8 +2065,7 @@ static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1991,8 +2076,7 @@ static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -2003,18 +2087,17 @@ static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } // ** main matmul entry point -int op_matmul(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; +static inline bool htp_is_permuted(const struct htp_tensor * t) { + return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; +} - htp_matmul_preamble; +int op_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; const char * op_type; @@ -2038,9 +2121,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "q4x4x2-fp32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_q4x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_q4x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2048,8 +2131,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other 
tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2067,9 +2150,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "q8x4x2-fp32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_q8x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_q8x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2077,8 +2160,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2096,9 +2179,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "mxfp4x4x2-f32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2106,8 +2189,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2122,20 +2205,69 @@ int op_matmul(struct htp_ops_context * octx) { break; case HTP_TYPE_F16: - op_type = "f16-f32"; - quant_job_func = NULL; // htp_quantize_f32_f16; - matmul_job_func = htp_matmul_f16_f32; + { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - // 
For all tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256); + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - need_quant = false; + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + op_type = "f16-f16"; + quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16; + if (src1_nrows > 1) { + matmul_job_func = htp_matmul_2d_f16_f16; + } else { + matmul_job_func = htp_matvec_2d_f16_f16; + } + + src1_row_size = f16_src1_row_size; // row size post quantization + + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + op_type = "f16-f32"; + matmul_job_func = htp_matmul_4d_f16_f32; + } else { + op_type = "f16-f16"; + matmul_job_func = htp_matmul_4d_f16_f16; + } + + src1_row_size = nb11; // original row size in DDR + + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + + need_quant = false; + } + } break; default: @@ -2166,6 +2298,9 @@ int op_matmul(struct htp_ops_context * octx) { octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; + if (need_quant) { // Run quant jobs const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); 
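Note on the scratchpad sizing used throughout the hunks above: the entire src1 tensor is staged in VTCM once, while src0 and dst get a fixed number of rows per worker thread, each row padded to the 128-byte HVX vector length and each scratchpad rounded up to a 256-byte boundary. The standalone C sketch below mirrors the budget check that selects between the optimized f16-f16 VTCM path and the DDR fallback. It is an illustration only: the helper names (round_up, fits_in_vtcm), the SPAD_*_NROWS values and the 8 MiB VTCM figure are assumptions standing in for the real htp_round_up / MM_SPAD_* / octx->ctx->vtcm_size machinery.

// Illustrative sketch only -- approximates the VTCM fit check above, not the actual HTP implementation.
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

// stand-in for htp_round_up(); align must be a power of two (128 and 256 are)
static size_t round_up(size_t v, size_t align) { return (v + align - 1) & ~(align - 1); }

#define SPAD_SRC0_NROWS 16 /* assumed value; the real MM_SPAD_SRC0_NROWS may differ */
#define SPAD_DST_NROWS  2  /* assumed value; the real MM_SPAD_DST_NROWS may differ */

static bool fits_in_vtcm(size_t ne10, size_t src1_nrows, size_t src0_row_size,
                         size_t dst_row_size, size_t n_threads, size_t vtcm_size) {
    const size_t src1_row  = round_up(ne10 * 2, 128);              // f16 src1 row, HVX-padded
    const size_t src1_spad = round_up(src1_row * src1_nrows, 256); // whole src1 tensor, staged once
    const size_t src0_spad = round_up(SPAD_SRC0_NROWS * round_up(src0_row_size, 128), 256) * n_threads;
    const size_t dst_spad  = round_up(SPAD_DST_NROWS * dst_row_size, 256) * n_threads;
    return src1_spad + src0_spad + dst_spad <= vtcm_size;
}

int main(void) {
    // hypothetical example: 4096-wide f16 rows, 64 src1 rows, 6 worker threads, 8 MiB VTCM
    printf("fits: %d\n", fits_in_vtcm(4096, 64, 4096 * 2, 4096 * 4, 6, (size_t) 8 << 20));
    return 0;
}

If the check fails, or src0 is batched or permuted, the patch instead takes the matmul_4d path that reads src1 rows directly from DDR and relies on the fastdiv values initialized above for broadcasting.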
@@ -2185,12 +2320,9 @@ int op_matmul(struct htp_ops_context * octx) { // ** main matmul-id entry point int op_matmul_id(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * ids = &octx->src2; - struct htp_tensor * dst = &octx->dst; + htp_matmul_tensors_preamble; - htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; const char * op_type; @@ -2228,8 +2360,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); @@ -2257,8 +2389,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); @@ -2286,8 +2418,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c new file mode 100644 index 0000000000..bdd64fcc8f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -0,0 +1,168 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define set_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = 
octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t ne1 = octx->dst.ne[1]; \ + \ + const uint32_t nr = ne01; + +static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + // copy row + hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? 
*(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); +} + +static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_set_rows(struct htp_ops_context * octx) { + set_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->set_rows_div_ne12 = init_fastdiv_values(ne12); + octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + switch(octx->dst.type) { + case HTP_TYPE_F32: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + break; + case HTP_TYPE_F16: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 5bf0cbf792..80d249a22c 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, (const uint8_t *) mp_f32, slope); } else { - hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale); + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); if (mp_f32) { if (softmax_ctx->use_f16) { for (int i = 0; i < ne00; ++i) { @@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct float max = hvx_self_max_f32((const uint8_t *) wp0, ne00); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); sum = sum > 0.0 ? 
(1.0 / sum) : 1; - hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum); + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); } } } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index bb7557b025..8ed1e5b661 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void scale_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + float scale = 0.f; + float bias = 0.f; + memcpy(&scale, &op_params[0], sizeof(float)); + memcpy(&bias, &op_params[1], sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + htp_l2fetch(src_local + row_elems, 1, row_size, row_size); + } + + hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + } +} + static void rms_norm_htp_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src, const float mean = sum / row_elems; const float scale = 1.0f / sqrtf(mean + epsilon); - hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale); + hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); } } } @@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, case HTP_OP_RMS_NORM: rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); break; + case HTP_OP_SCALE: + scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; default: break; @@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { unary_op_func = unary_job_dispatcher_f32; op_type = "rmsnorm-f32"; break; + case HTP_OP_SCALE: + unary_op_func = unary_job_dispatcher_f32; + op_type = "scale-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 502a4deebc..3c13777b8a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -550,6 +550,8 @@ struct vk_device_struct { uint64_t max_memory_allocation_size; uint64_t max_buffer_size; uint64_t suballocation_block_size; + uint64_t min_imported_host_pointer_alignment; + bool external_memory_host {}; bool fp16; bool bf16; bool pipeline_robustness; @@ -2410,7 +2412,8 @@ static std::vector ggml_vk_find_memory_properties(const vk::PhysicalDe return indices; } -static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list, + void *import_ptr = nullptr) { VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); if (size > device->max_buffer_size) { throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); @@ -2439,6 +2442,12 @@ static 
vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std nullptr, }; + vk::ExternalMemoryBufferCreateInfo external_memory_bci; + if (import_ptr) { + external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + buffer_create_info.setPNext(&external_memory_bci); + } + buf->buffer = device->device.createBuffer(buffer_create_info); vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); @@ -2453,35 +2462,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std mem_flags_info.setPNext(&mem_priority_info); } - for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { - const auto & req_flags = *it; - - const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); - - if (memory_type_indices.empty()) { - continue; + if (import_ptr) { + vk::MemoryHostPointerPropertiesEXT host_pointer_props; + try { + host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what()); + device->device.destroyBuffer(buf->buffer); + return {}; } - buf->memory_property_flags = req_flags; + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - bool done = false; + uint32_t memory_type_idx; + vk::MemoryPropertyFlags property_flags = *req_flags_list.begin(); + for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) { + if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } + if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } - for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); - done = true; + vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx]; + // check for visible+coherent+cached. Other flags (e.g. 
devicelocal) are allowed + if ((memory_type.propertyFlags & property_flags) == property_flags) { + property_flags = memory_type.propertyFlags; break; - } catch (const vk::SystemError& e) { - // loop and retry - // during last attempt throw the exception - if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { - device->device.destroyBuffer(buf->buffer); - throw e; - } } } + if (memory_type_idx == 32) { + GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n"); + device->device.destroyBuffer(buf->buffer); + return {}; + } - if (done) { - break; + buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags; + try { + vk::ImportMemoryHostPointerInfoEXT import_info; + import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + import_info.pHostPointer = import_ptr; + import_info.setPNext(&mem_flags_info); + buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info }); + } catch (const vk::SystemError& e) { + } + } else { + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; + + const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); + + if (memory_type_indices.empty()) { + continue; + } + buf->memory_property_flags = req_flags; + + bool done = false; + + for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); + done = true; + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + } + + if (done) { + break; + } } } @@ -2492,8 +2546,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->ptr = nullptr; - if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { - buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + if (import_ptr) { + buf->ptr = import_ptr; + } else { + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } } device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); @@ -4447,6 +4505,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) { device->memory_priority = true; + } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) { + device->external_memory_host = true; } } @@ -4461,6 +4521,7 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props; props2.pNext = &props3; props3.pNext = &subgroup_props; @@ -4500,11 +4561,22 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; } + if (device->external_memory_host) { + last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props; + last_struct = 
(VkBaseOutStructure *)&external_memory_host_props; + } + device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; device->driver_id = driver_props.driverID; + if (device->driver_id == vk::DriverId::eMoltenvk) { + // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622 + // is available in the Vulkan SDK. + device->external_memory_host = false; + } + // Implementing the async backend interfaces seems broken on older Intel HW, // see https://github.com/ggml-org/llama.cpp/issues/17302. device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL || @@ -4586,6 +4658,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); @@ -4717,6 +4791,10 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_pipeline_executable_properties"); } + if (device->external_memory_host) { + device_extensions.push_back("VK_EXT_external_memory_host"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->pipeline_executable_properties_support = pipeline_executable_properties_support; @@ -14773,6 +14851,51 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); } +static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) { + if (!device->external_memory_host) { + return {}; + } + + uintptr_t uptr = reinterpret_cast(ptr); + if (uptr & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + if (size & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + + const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached; + + vk_buffer buf {}; + try { + buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what()); + } + + return buf; +} + +static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")"); + GGML_UNUSED(max_tensor_size); + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + + vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size); + + if (!buf) { + return {}; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name); + + ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size); + + return ret; +} + static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .get_name = */ ggml_backend_vk_device_get_name, /* 
.get_description = */ ggml_backend_vk_device_get_description, @@ -14782,7 +14905,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .init_backend = */ ggml_backend_vk_device_init, /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, - /* .buffer_from_host_ptr = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr, /* .supports_op = */ ggml_backend_vk_device_supports_op, /* .supports_buft = */ ggml_backend_vk_device_supports_buft, /* .offload_op = */ ggml_backend_vk_device_offload_op, diff --git a/scripts/snapdragon/adb/run-bench.sh b/scripts/snapdragon/adb/run-bench.sh index b2e651e749..1a7d8c9fd6 100755 --- a/scripts/snapdragon/adb/run-bench.sh +++ b/scripts/snapdragon/adb/run-bench.sh @@ -16,8 +16,14 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf" device="HTP0" [ "$D" != "" ] && device="$D" -verbose="" -[ "$V" != "" ] && verbose="$V" +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v" + +experimental= +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" opmask= [ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" @@ -34,7 +40,7 @@ adb $adbserial shell " \ cd $basedir; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ + $ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --batch-size 128 -ngl 99 $@ \ + --batch-size 128 -ngl 99 $cli_opts $@ \ " diff --git a/src/llama.cpp b/src/llama.cpp index 98fb770844..0162ae8d58 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -359,6 +359,11 @@ static void llama_params_fit_impl( // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: layer_fraction_t overflow_type = LAYER_FRACTION_MOE; + + uint32_t n_full() const { + assert(n_layer >= n_part); + return n_layer - n_part; + } }; const size_t ntbo = llama_max_tensor_buft_overrides(); @@ -382,7 +387,7 @@ static void llama_params_fit_impl( size_t itbo = 0; for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part; + il0 += ngl_per_device[id].n_full(); for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { if (itbo + 1 >= ntbo) { tensor_buft_overrides[itbo].pattern = nullptr; @@ -393,7 +398,7 @@ static void llama_params_fit_impl( + std::to_string(ntbo) + " is insufficient for model"); } tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = overflow_bufts[id]; + tensor_buft_overrides[itbo].buft = il == il0 ? 
overflow_bufts[id] : ggml_backend_cpu_buffer_type(); itbo++; } il0 += ngl_per_device[id].n_part; @@ -468,20 +473,14 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); } - std::vector overflow_bufts; // which bufts the partial layers of a device overflow to: + std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd - 1; ++id) { - overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1])); + for (size_t id = 0; id < nd; id++) { + overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); } - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); std::vector ngl_per_device(nd); std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - if (hp_nex > 0) { - for (size_t id = 0; id < nd; id++) { - ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE; - } - } // optimize the number of layers per device using the method of false position: // - ngl_per_device has 0 layers for each device, lower bound @@ -512,9 +511,6 @@ static void llama_params_fit_impl( if (mem_high[id] > targets[id]) { assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - if (hp_nex > 0 && size_t(id) == nd - 1) { - delta--; - } LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); @@ -524,7 +520,8 @@ static void llama_params_fit_impl( std::vector ngl_per_device_test = ngl_per_device; ngl_per_device_test[id].n_layer += step_size; if (hp_nex) { - ngl_per_device_test[id].n_part += step_size; + ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? + step_size - 1 : step_size; // the first layer is the output layer which must always be full } const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); @@ -573,7 +570,7 @@ static void llama_params_fit_impl( assert(id_dense_start < nd); LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start; id++) { + for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { std::vector ngl_per_device_high = ngl_per_device; for (size_t jd = id_dense_start; jd < nd; jd++) { const uint32_t n_layer_move = jd < nd - 1 ? 
ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; @@ -585,12 +582,8 @@ static void llama_params_fit_impl( std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part); - assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part); - assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - >= ngl_per_device[id].n_layer - ngl_per_device[id].n_part); - uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); step_size = std::max(step_size, uint32_t(1)); @@ -606,7 +599,7 @@ static void llama_params_fit_impl( ngl_per_device_test[id].n_layer += n_convert_jd; n_converted_test += n_convert_jd; - if (ngl_per_device_test[id_dense_start_test].n_layer > 0) { + if (ngl_per_device_test[id_dense_start_test].n_part > 0) { break; } } @@ -625,8 +618,8 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); } - delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); } } else { ngl_per_device = ngl_per_device_high; @@ -644,14 +637,19 @@ static void llama_params_fit_impl( ngl_per_device_test[id_dense_start_test].n_part--; ngl_per_device_test[id].n_layer++; ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_layer == 0) { + if (ngl_per_device_test[id_dense_start_test].n_part == 0) { id_dense_start_test++; } ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; + std::vector overflow_bufts_test = overflow_bufts; + if (id < nd - 1) { + overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); + } LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", @@ -659,9 +657,10 @@ static void llama_params_fit_impl( ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 
1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", @@ -670,9 +669,10 @@ static void llama_params_fit_impl( } else { ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", @@ -687,6 +687,14 @@ static void llama_params_fit_impl( __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); } + // print info for devices that were not changed during the conversion from dense only to full layers: + for (size_t id = id_dense_start + 1; id < nd; id++) { + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LLAMA_LOG_INFO( + "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); + } + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); } diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 1bbb745e78..e995974a2e 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -127,6 +127,15 @@ int main(void) { assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE)); assert(params.speculative.n_max == 123); + // multi-value args (CSV) + argv = {"binary_name", "--lora", "file1.gguf,\"file2,2.gguf\",\"file3\"\"3\"\".gguf\",file4\".gguf"}; + assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); + assert(params.lora_adapters.size() == 4); + assert(params.lora_adapters[0].path == "file1.gguf"); + assert(params.lora_adapters[1].path == "file2,2.gguf"); + assert(params.lora_adapters[2].path == "file3\"3\".gguf"); + assert(params.lora_adapters[3].path == "file4\".gguf"); + // skip this part on windows, because setenv is not supported #ifdef _WIN32 printf("test-arg-parser: skip on windows build\n"); diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index e99101184b..e8eef035ff 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -9,207 +9,250 @@ #include #include -// most of the code here is copied from whisper.cpp +// some of the code here is copied from whisper.cpp constexpr bool DEBUG = false; -struct mtmd_audio_mel_filters { - int32_t n_mel; - int32_t n_fft; +void mtmd_audio_cache::fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } +} - std::vector data; -}; +void mtmd_audio_cache::fill_hann_window(int length, bool periodic) { + hann_window.resize(length); + int 
offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } +} -// note: this global cache is shared among all preprocessors -// if we want to use multiple preprocessors at the same time, -// we will need to enclose it in the preprocessor class in the future -static struct mtmd_audio_global_cache { - // precomputed sin/cos table for FFT - std::vector sin_vals; - std::vector cos_vals; - - // hann window - std::vector hann_window; - - // mel filter bank - mtmd_audio_mel_filters filters; - - void fill_sin_cos_table(int n) { - sin_vals.resize(n); - cos_vals.resize(n); - for (int i = 0; i < n; i++) { - double theta = (2 * M_PI * i) / n; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); - } +void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, + float fmin, + float fmax, + bool slaney_area_norm, + float scale) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; } - void fill_hann_window(int length, bool periodic) { - hann_window.resize(length); - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); - } + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); } - // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. - // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. - void fill_mel_filterbank_matrix( - int n_mel, - int n_fft, - int sample_rate, // e.g. 16000 - float fmin = 0.0f, // e.g. 0.0 - float fmax = -1.0f, // e.g. sr/2; pass -1 for auto - bool slaney_area_norm = true, - float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code - ) { - GGML_ASSERT(n_mel > 0 && n_fft > 1); - if (fmax <= 0.0f) { - fmax = 0.5f * sample_rate; - } + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } - // Slaney scale (matches librosa default) - const double min_log_hz = 1000.0; - const double lin_slope = 3 / 200.; - const double min_log_mel = min_log_hz * lin_slope; - const double log_step = log(6.4) / 27.0; - auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { - return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; - }; - auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { - return (m < min_log_mel) ? 
m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); - }; + const int n_fft_bins = n_fft / 2 + 1; - // infer N_fft from n_fft_bins - const double bin_hz_step = double(sample_rate) / double(n_fft); + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; - // mel grid: n_mel + 2 edges - const double m_lo = hz_to_mel(fmin); - const double m_hi = hz_to_mel(fmax); - std::vector mel_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); - } + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - // convert to Hz - std::vector hz_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - hz_pts[i] = mel_to_hz(mel_pts[i]); - } - - const int n_fft_bins = n_fft / 2 + 1; - - // filterbank - std::vector out(n_mel * n_fft_bins, 0); - for (int m = 0; m < n_mel; ++m) { - const double f_left = hz_pts[m]; - const double f_center = hz_pts[m + 1]; - const double f_right = hz_pts[m + 2]; - - const double denom_l = std::max(1e-30, f_center - f_left); - const double denom_r = std::max(1e-30, f_right - f_center); - const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - - for (int k = 0; k < n_fft_bins; ++k) { - const double f = k * bin_hz_step; - double w = 0.0; - if (f >= f_left && f <= f_center) { - w = (f - f_left) / denom_l; - } else if (f > f_center && f <= f_right) { - w = (f_right - f) / denom_r; - } - out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); } + } - filters.n_mel = n_mel; - filters.n_fft = n_fft; - filters.data = std::move(out); + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); - if (DEBUG) { // debug - for (size_t i = 0; i < filters.data.size(); ++i) { - if (filters.data[i] != 0.0f) { - printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); - } + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); } } } -} g_cache; +} -// naive Discrete Fourier Transform -// input is real-valued -// output is complex-valued -static void dft(const float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); - const int sin_cos_step = n_sin_cos_vals / N; +// Unified DFT implementation for both forward and inverse transforms +// Template parameters: +// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling +// true = IDFT with exp(+2πi·k·n/N), scales by 1/N +// RealInput: true = input is real-valued (stride 1), avoids imaginary computations +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 
1.0f : -1.0f; + const float scale = Inverse ? (1.0f / N) : 1.0f; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N - re += in[n] * g_cache.cos_vals[idx]; // cos(t) - im -= in[n] * g_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % n_sin_cos_vals; + float cos_val = cache.cos_vals[idx]; + float sin_val = cache.sin_vals[idx]; + + if constexpr (RealInput) { + // Real input: in_im = 0, simplifies to: + // re += in_re * cos_val + // im += sign * in_re * sin_val + float in_re = in[n]; + re += in_re * cos_val; + im += sign * in_re * sin_val; + } else { + float in_re = in[n * 2 + 0]; + float in_im = in[n * 2 + 1]; + // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i + re += in_re * cos_val - sign * in_im * sin_val; + im += sign * in_re * sin_val + in_im * cos_val; + } } - out[k*2 + 0] = re; - out[k*2 + 1] = im; + out[k * 2 + 0] = re * scale; + out[k * 2 + 1] = im * scale; } } -// Cooley-Tukey FFT -// poor man's implementation - use something better -// input is real-valued -// output is complex-valued -static void fft(float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); +// Cooley-Tukey FFT/IFFT unified implementation +// Template parameters: +// Inverse: false = FFT with exp(-2πi·k/N), no scaling +// true = IFFT with exp(+2πi·k/N), scales by 0.5 at each level +// RealInput: true = input is real-valued (stride 1) +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + if (N == 1) { out[0] = in[0]; - out[1] = 0; + if constexpr (RealInput) { + out[1] = 0.0f; + } else { + out[1] = in[1]; + } return; } const int half_N = N / 2; - if (N - half_N*2 == 1) { - dft(in, N, out); + if (N - half_N * 2 == 1) { + // Odd N: fall back to DFT + dft_impl(cache, in, N, out); return; } - float* even = in + N; - for (int i = 0; i < half_N; ++i) { - even[i]= in[2*i]; - } - float* even_fft = out + 2 * N; - fft(even, half_N, even_fft); + // Split into even and odd + if constexpr (RealInput) { + // Real input: stride is 1, copy only real values + float * even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i] = in[2 * i]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); - float* odd = even; - for (int i = 0; i < half_N; ++i) { - odd[i] = in[2*i + 1]; + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2 * i + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); + } else { + // Complex input: stride is 2, copy complex pairs + float * even = in + N * 2; + for (int i = 0; i < half_N; ++i) { + even[i * 2 + 0] = in[2 * i * 2 + 0]; + even[i * 2 + 1] = in[2 * i * 2 + 1]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); + + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0]; + odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); } - float* odd_fft = even_fft + N; - fft(odd, half_N, odd_fft); + + float * even_fft = out + 2 * N; + float * odd_fft = even_fft + N; const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + constexpr float scale = Inverse ? 
0.5f : 1.0f; + for (int k = 0; k < half_N; k++) { - int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = g_cache.cos_vals[idx]; // cos(t) - float im = -g_cache.sin_vals[idx]; // sin(t) + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; + float im = sign * cache.sin_vals[idx]; - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd); + out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd); - out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd); + out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd); } } +// Forward FFT for real input (used by mel spectrogram) +static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + +// Inverse FFT for complex input +static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + struct filter_params { int32_t n_mel; int32_t n_fft_bins; @@ -222,20 +265,27 @@ struct filter_params { bool norm_per_feature = false; }; -static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, - int n_samples, int frame_size, int frame_step, int n_threads, - const filter_params & params, mtmd_audio_mel & out) { +static void log_mel_spectrogram_worker_thread(int ith, + const float * hann, + const std::vector & samples, + int n_samples, + int frame_size, + int frame_step, + int n_threads, + const filter_params & params, + const mtmd_audio_cache & cache, + mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); int n_fft_bins = params.n_fft_bins; int i = ith; - const auto & filters = g_cache.filters; + const auto & filters = cache.filters; // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); - GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); + GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; @@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const } // FFT - fft(fft_in.data(), frame_size, fft_out.data()); + fft(cache, fft_in.data(), frame_size, fft_out.data()); // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. 
@@ -298,6 +348,7 @@ static bool log_mel_spectrogram( const int n_samples_in, const int n_threads, const filter_params & params, + const mtmd_audio_cache & cache, mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); @@ -305,9 +356,9 @@ static bool log_mel_spectrogram( int n_samples = n_samples_in; // Hann window - const float * hann = g_cache.hann_window.data(); - const int frame_size = (params.n_fft_bins - 1) * 2; - const int frame_step = params.hop_length; + const float * hann = cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; // Padding std::vector samples_padded; @@ -335,9 +386,9 @@ static bool log_mel_spectrogram( // preemphasis if (params.preemph) { - const int pad_amount = frame_size / 2; + const int pad_amount = frame_size / 2; const float preemph = 0.97f; - float prev = samples_padded[pad_amount]; + float prev = samples_padded[pad_amount]; for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { float cur = samples_padded[i]; samples_padded[i] = cur - preemph * prev; @@ -372,14 +423,14 @@ static bool log_mel_spectrogram( { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { - workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples, frame_size, frame_step, n_threads, - std::cref(params), std::ref(out)); + workers[iw] = + std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples, + frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, + cache, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } @@ -404,7 +455,7 @@ static bool log_mel_spectrogram( for (int j = 0; j < effective_n_len; ++j) { auto &value = out.data[i * out.n_len + j]; - value = (value - mean) / mstd; + value = (value - mean) / mstd; } // pad the rest with zeros @@ -450,18 +501,14 @@ static bool log_mel_spectrogram( // void mtmd_audio_preprocessor_whisper::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_whisper::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { if (n_samples == 0) { // empty audio return false; @@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // if input is too short, pad with zeros // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram // TODO: maybe handle this better - size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin if (n_samples < min_samples) { smpl.resize(min_samples, 0.0f); 
std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); @@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess( params.hop_length = hparams.audio_hop_len; params.sample_rate = hparams.audio_sample_rate; params.center_padding = false; - params.preemph = 0.0f; // disabled + params.preemph = 0.0f; // disabled params.use_natural_log = false; params.norm_per_feature = false; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess( printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); } const size_t frames_per_chunk = 3000; - GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); - for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { - int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off); - if ((size_t)n_len < frames_per_chunk) { - break; // last uncomplete chunk will always be a padded chunk, safe to ignore + GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk); + for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { + int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); + if ((size_t) n_len < frames_per_chunk) { + break; // last uncomplete chunk will always be a padded chunk, safe to ignore } mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; - out_chunk.n_len_org = out_full.n_mel; // unused + out_chunk.n_len_org = out_full.n_mel; // unused out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); for (int i = 0; i < out_full.n_mel; i++) { - auto src = out_full.data.begin() + i*out_full.n_len + off; + auto src = out_full.data.begin() + i * out_full.n_len + off; out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk); } @@ -541,18 +585,14 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // void mtmd_audio_preprocessor_conformer::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_conformer::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { // empty audio if (n_samples == 0) { return false; @@ -569,18 +609,15 @@ bool mtmd_audio_preprocessor_conformer::preprocess( params.use_natural_log = true; params.norm_per_feature = true; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + 
GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -588,3 +625,106 @@ bool mtmd_audio_preprocessor_conformer::preprocess( output.push_back(std::move(out_full)); return true; } + +// +// mtmd_audio_streaming_istft implementation +// + +mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) : + n_fft(n_fft), + hop_length(hop_length), + n_fft_bins(n_fft / 2 + 1), + overlap_buffer(n_fft, 0.0f), + window_sum_buffer(n_fft, 0.0f), + padding_to_remove((n_fft - hop_length) / 2), + ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT + ifft_out(n_fft * 2 * 4, 0.0f) { + cache.fill_sin_cos_table(n_fft); + cache.fill_hann_window(n_fft, true); +} + +void mtmd_audio_streaming_istft::reset() { + std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f); + std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f); + padding_to_remove = (n_fft - hop_length) / 2; +} + +std::vector mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) { + std::vector output(hop_length); + + // copy frequencies + for (int j = 0; j < n_fft_bins; j++) { + ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0]; + ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1]; + } + + // mirror negative frequencies + for (int j = 1; j < n_fft_bins - 1; j++) { + int mirror_idx = n_fft - j; + ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0]; + ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate + } + + ifft(cache, ifft_in.data(), n_fft, ifft_out.data()); + + // update window sum and overlap buffer + for (int j = 0; j < n_fft; j++) { + window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j]; + overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j]; + } + + // extract hop_length samples with normalization + for (int i = 0; i < hop_length; i++) { + if (window_sum_buffer[i] > 1e-8f) { + output[i] = overlap_buffer[i] / window_sum_buffer[i]; + } else { + output[i] = overlap_buffer[i]; + } + } + + // shift buffers left by hop_length + std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f); + + // Remove padding if needed + int to_remove = std::min(padding_to_remove, (int) output.size()); + padding_to_remove -= to_remove; + output.erase(output.begin(), output.begin() + to_remove); + + return output; +} + +std::vector mtmd_audio_streaming_istft::flush() { + std::vector output; + + // Extract remaining samples from overlap buffer + // Continue until we've extracted all meaningful samples + int remaining = n_fft - hop_length; + while (remaining > 0) { + int chunk_size = std::min(remaining, hop_length); + + for (int i = 0; i < chunk_size; i++) { + float sample; + if (window_sum_buffer[i] > 1e-8f) { + sample = overlap_buffer[i] / window_sum_buffer[i]; + } else { + sample = overlap_buffer[i]; + } + output.push_back(sample); + } + + // Shift buffers + std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin()); + 
std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f); + + remaining -= chunk_size; + } + + return output; +} diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index d484c9d030..016c7392e4 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -17,6 +17,38 @@ struct mtmd_audio_mel { std::vector data; }; +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +// cache for audio processing, each processor instance owns its own cache +struct mtmd_audio_cache { + std::vector sin_vals; + std::vector cos_vals; + + std::vector hann_window; + + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n); + + void fill_hann_window(int length, bool periodic); + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling + ); +}; + struct mtmd_audio_preprocessor { const clip_hparams & hparams; @@ -31,10 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; }; struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor { mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; +}; + +// +// streaming ISTFT - converts spectrogram frames back to audio one frame at a time +// +struct mtmd_audio_streaming_istft { + mtmd_audio_streaming_istft(int n_fft, int hop_length); + + // reset streaming state + void reset(); + + // process a single STFT frame (streaming) + // frame_spectrum: [n_fft_bins x 2] interleaved real/imag + // returns: up to hop_length samples + std::vector process_frame(const float * frame_spectrum); + + // flush remaining samples at end of stream + std::vector flush(); + + private: + int n_fft; + int hop_length; + int n_fft_bins; + + // Own cache for output processing + mtmd_audio_cache cache; + + // Streaming state + std::vector overlap_buffer; + std::vector window_sum_buffer; + int padding_to_remove; + + // Working buffers for IFFT + std::vector ifft_in; + std::vector ifft_out; }; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 6d374131e3..ed4f6546ea 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -814,6 +814,15 @@ json server_task_result_cmpl_final::to_json_anthropic() { msg.content = content; } + // thinking block comes first (Anthropic extended thinking format) + if (!msg.reasoning_content.empty()) { + content_blocks.push_back({ + {"type", "thinking"}, + {"thinking", msg.reasoning_content}, + {"signature", ""} // empty signature for local models (no cryptographic verification) + }); + } + if (!msg.content.empty()) { content_blocks.push_back({ {"type", 
"text"}, @@ -862,20 +871,57 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; } - bool has_text = !oaicompat_msg.content.empty(); + bool has_thinking = !oaicompat_msg.reasoning_content.empty(); + bool has_text = !oaicompat_msg.content.empty(); size_t num_tool_calls = oaicompat_msg.tool_calls.size(); - bool text_block_started = false; + // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+) + size_t thinking_block_index = 0; + size_t text_block_index = has_thinking ? 1 : 0; + + bool thinking_block_started = false; + bool text_block_started = false; std::unordered_set tool_calls_started; for (const auto & diff : oaicompat_msg_diffs) { + // handle thinking/reasoning content + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", thinking_block_index}, + {"content_block", { + {"type", "thinking"}, + {"thinking", ""} + }} + }} + }); + thinking_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "thinking_delta"}, + {"thinking", diff.reasoning_content_delta} + }} + }} + }); + } + + // handle regular text content if (!diff.content_delta.empty()) { if (!text_block_started) { events.push_back({ {"event", "content_block_start"}, {"data", { {"type", "content_block_start"}, - {"index", 0}, + {"index", text_block_index}, {"content_block", { {"type", "text"}, {"text", ""} @@ -889,7 +935,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { {"event", "content_block_delta"}, {"data", { {"type", "content_block_delta"}, - {"index", 0}, + {"index", text_block_index}, {"delta", { {"type", "text_delta"}, {"text", diff.content_delta} @@ -898,8 +944,9 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { }); } + // handle tool calls if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; + size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + diff.tool_call_index; if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; @@ -935,18 +982,42 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { } } + // close content blocks in order + if (has_thinking) { + // Anthropic API requires a signature_delta before closing thinking blocks + // We use an empty signature since we can't generate a cryptographic signature for local models + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "signature_delta"}, + {"signature", ""} + }} + }} + }); + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", thinking_block_index} + }} + }); + } + if (has_text) { events.push_back({ {"event", "content_block_stop"}, {"data", { {"type", "content_block_stop"}, - {"index", 0} + {"index", text_block_index} }} }); } for (size_t i = 0; i < num_tool_calls; i++) { - size_t content_block_index = (has_text ? 1 : 0) + i; + size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 
1 : 0) + i; events.push_back({ {"event", "content_block_stop"}, {"data", { @@ -1154,11 +1225,10 @@ json server_task_result_rerank::to_json() { json server_task_result_cmpl_partial::to_json_anthropic() { json events = json::array(); bool first = (n_decoded == 1); - bool text_block_started = false; + // use member variables to track block state across streaming calls + // (anthropic_thinking_block_started, anthropic_text_block_started) if (first) { - text_block_started = false; - events.push_back({ {"event", "message_start"}, {"data", { @@ -1180,28 +1250,69 @@ json server_task_result_cmpl_partial::to_json_anthropic() { }); } + // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+) + size_t thinking_block_index = 0; + // use anthropic_has_reasoning (set in update()) to know if ANY reasoning was generated + size_t text_block_index = anthropic_has_reasoning ? 1 : 0; + + // use local copies of streaming state (copied from task_result_state in update()) + // these reflect the state BEFORE this chunk was processed + bool thinking_started = anthropic_thinking_block_started; + bool text_started = anthropic_text_block_started; + for (const auto & diff : oaicompat_msg_diffs) { - if (!diff.content_delta.empty()) { - if (!text_block_started) { + // handle thinking/reasoning content + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_started) { events.push_back({ {"event", "content_block_start"}, {"data", { {"type", "content_block_start"}, - {"index", 0}, + {"index", thinking_block_index}, {"content_block", { - {"type", "text"}, - {"text", ""} + {"type", "thinking"}, + {"thinking", ""} }} }} }); - text_block_started = true; + thinking_started = true; } events.push_back({ {"event", "content_block_delta"}, {"data", { {"type", "content_block_delta"}, - {"index", 0}, + {"index", thinking_block_index}, + {"delta", { + {"type", "thinking_delta"}, + {"thinking", diff.reasoning_content_delta} + }} + }} + }); + } + + // handle regular text content + if (!diff.content_delta.empty()) { + if (!text_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", text_block_index}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", text_block_index}, {"delta", { {"type", "text_delta"}, {"text", diff.content_delta} @@ -1210,8 +1321,10 @@ json server_task_result_cmpl_partial::to_json_anthropic() { }); } + // handle tool calls if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index; + // use anthropic_has_reasoning for thinking block count (persists across calls) + size_t content_block_index = (anthropic_has_reasoning ? 1 : 0) + (text_started ? 
1 : 0) + diff.tool_call_index; if (!diff.tool_call_delta.name.empty()) { events.push_back({ diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 687770de5e..ead1491182 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -96,6 +96,10 @@ struct task_result_state { std::string generated_text; // append new chunks of generated text here std::vector generated_tool_call_ids; + // for Anthropic API streaming: track content block state across chunks + bool anthropic_thinking_block_started = false; + bool anthropic_text_block_started = false; + task_result_state(const common_chat_syntax & oaicompat_chat_syntax) : oaicompat_chat_syntax(oaicompat_chat_syntax) {} @@ -337,6 +341,12 @@ struct server_task_result_cmpl_partial : server_task_result { std::vector oaicompat_msg_diffs; // to be populated by update() bool is_updated = false; + // for Anthropic API: track if any reasoning content has been generated + bool anthropic_has_reasoning = false; + // Streaming state copied from task_result_state for this chunk + bool anthropic_thinking_block_started = false; + bool anthropic_text_block_started = false; + virtual bool is_stop() override { return false; // in stream mode, partial responses are not considered stop } @@ -346,6 +356,22 @@ struct server_task_result_cmpl_partial : server_task_result { virtual void update(task_result_state & state) override { is_updated = true; state.update_chat_msg(content, true, oaicompat_msg_diffs); + // track if the accumulated message has any reasoning content + anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty(); + + // Copy current state for use in to_json_anthropic() (reflects state BEFORE this chunk) + anthropic_thinking_block_started = state.anthropic_thinking_block_started; + anthropic_text_block_started = state.anthropic_text_block_started; + + // Pre-compute state updates based on diffs (for next chunk) + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.reasoning_content_delta.empty() && !state.anthropic_thinking_block_started) { + state.anthropic_thinking_block_started = true; + } + if (!diff.content_delta.empty() && !state.anthropic_text_block_started) { + state.anthropic_text_block_started = true; + } + } } json to_json_non_oaicompat(); diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py index e0a003557e..e16e0235c6 100644 --- a/tools/server/tests/unit/test_compat_anthropic.py +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -805,3 +805,92 @@ def test_anthropic_vs_openai_different_response_format(): assert "input_tokens" in anthropic_res.body["usage"] assert "completion_tokens" in openai_res.body["usage"] assert "output_tokens" in anthropic_res.body["usage"] + + +# Extended thinking tests with reasoning models + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [False, True]) +def test_anthropic_thinking_with_reasoning_model(stream): + """Test that thinking content blocks are properly returned for reasoning models""" + global server + server = ServerProcess() + server.model_hf_repo = "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF" + server.model_hf_file = "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf" + server.reasoning_format = "deepseek" + server.jinja = True + server.n_ctx = 8192 + server.n_predict = 1024 + server.server_port = 8084 + server.start(timeout_seconds=600) # large model needs time to download + + if stream: + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 
1024, + "thinking": { + "type": "enabled", + "budget_tokens": 500 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ], + "stream": True + }) + + events = list(res) + + # should have thinking content block events + thinking_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "thinking"] + assert len(thinking_starts) > 0, "Should have thinking content_block_start event" + assert thinking_starts[0]["index"] == 0, "Thinking block should be at index 0" + + # should have thinking_delta events + thinking_deltas = [e for e in events if + e.get("type") == "content_block_delta" and + e.get("delta", {}).get("type") == "thinking_delta"] + assert len(thinking_deltas) > 0, "Should have thinking_delta events" + + # should have signature_delta event before thinking block closes (Anthropic API requirement) + signature_deltas = [e for e in events if + e.get("type") == "content_block_delta" and + e.get("delta", {}).get("type") == "signature_delta"] + assert len(signature_deltas) > 0, "Should have signature_delta event for thinking block" + + # should have text block after thinking + text_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "text"] + assert len(text_starts) > 0, "Should have text content_block_start event" + assert text_starts[0]["index"] == 1, "Text block should be at index 1 (after thinking)" + else: + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 1024, + "thinking": { + "type": "enabled", + "budget_tokens": 500 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + content = res.body["content"] + assert len(content) >= 2, "Should have at least thinking and text blocks" + + # first block should be thinking + thinking_blocks = [b for b in content if b.get("type") == "thinking"] + assert len(thinking_blocks) > 0, "Should have thinking content block" + assert "thinking" in thinking_blocks[0], "Thinking block should have 'thinking' field" + assert len(thinking_blocks[0]["thinking"]) > 0, "Thinking content should not be empty" + assert "signature" in thinking_blocks[0], "Thinking block should have 'signature' field (Anthropic API requirement)" + + # should also have text block + text_blocks = [b for b in content if b.get("type") == "text"] + assert len(text_blocks) > 0, "Should have text content block"
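For reference, the content-block index arithmetic that both to_json_anthropic() and the streaming partial path rely on — a thinking block first at index 0 when present, the text block next, and tool_use blocks after both — can be summarized in a small standalone sketch (the struct and helper names below are illustrative only, not part of the patch):

#include <cstdio>

// index layout: thinking (0 if present) -> text (0 or 1) -> tool_use blocks after both
struct block_layout {
    bool has_thinking;
    bool has_text;

    int thinking_index() const { return 0; }
    int text_index() const { return has_thinking ? 1 : 0; }
    int tool_index(int tool_call_index) const {
        return (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + tool_call_index;
    }
};

int main() {
    const block_layout reasoning = {true, true};
    const block_layout plain     = {false, true};

    // reasoning model: thinking=0, text=1, first tool call=2
    printf("%d %d %d\n", reasoning.thinking_index(), reasoning.text_index(), reasoning.tool_index(0));
    // no reasoning content: text=0, first tool call=1
    printf("%d %d\n", plain.text_index(), plain.tool_index(0));
    return 0;
}

This matches what test_anthropic_thinking_with_reasoning_model asserts above: the thinking block streams at index 0 and the text block at index 1 whenever reasoning content is present.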