Merge branch 'ggml-org:master' into Kimi-Linear

This commit is contained in:
ymcki 2026-01-07 16:32:12 +08:00 committed by GitHub
commit 40f6118192
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 3108 additions and 996 deletions

View File

@ -854,6 +854,54 @@ bool common_arg_utils::is_autoy(const std::string & value) {
return value == "auto" || value == "-1";
}
// Simple CSV parser that handles quoted fields and escaped quotes
// example:
// input: value1,"value, with, commas","value with ""escaped"" quotes",value4
// output: [value1] [value, with, commas] [value with "escaped" quotes] [value4]
static std::vector<std::string> parse_csv_row(const std::string& input) {
std::vector<std::string> fields;
std::string field;
bool in_quotes = false;
for (size_t i = 0; i < input.length(); ++i) {
char ch = input[i];
if (ch == '"') {
if (!in_quotes) {
// start of quoted field (only valid if at beginning of field)
if (!field.empty()) {
// quote appeared in middle of unquoted field, treat as literal
field += '"';
} else {
in_quotes = true; // start
}
} else {
if (i + 1 < input.length() && input[i + 1] == '"') {
// escaped quote: ""
field += '"';
++i; // skip the next quote
} else {
in_quotes = false; // end
}
}
} else if (ch == ',') {
if (in_quotes) {
field += ',';
} else {
fields.push_back(std::move(field));
field.clear();
}
} else {
field += ch;
}
}
// Add the last field
fields.push_back(std::move(field));
return fields;
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// per-example default params
// we define here to make sure it's included in llama-gen-docs
@ -1250,7 +1298,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--in-file"}, "FNAME",
"an input file (use comma-separated values to specify multiple files)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
for (const auto & item : parse_csv_row(value)) {
std::ifstream file(item);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
@ -2002,7 +2050,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--image", "--audio"}, "FILE",
"path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
for (const auto & item : parse_csv_row(value)) {
params.image.emplace_back(item);
}
}
@ -2259,37 +2307,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
));
add_opt(common_arg(
{"--override-kv"}, "KEY=TYPE:VALUE,...",
"advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n"
"advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated values.\n"
"types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false",
[](common_params & params, const std::string & value) {
std::vector<std::string> kv_overrides;
std::string current;
bool escaping = false;
for (const char c : value) {
if (escaping) {
current.push_back(c);
escaping = false;
} else if (c == '\\') {
escaping = true;
} else if (c == ',') {
kv_overrides.push_back(current);
current.clear();
} else {
current.push_back(c);
}
}
if (escaping) {
current.push_back('\\');
}
kv_overrides.push_back(current);
for (const auto & kv_override : kv_overrides) {
if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) {
throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str()));
for (const auto & item : parse_csv_row(value)) {
if (!string_parse_kv_override(item.c_str(), params.kv_overrides)) {
throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", item.c_str()));
}
}
}
@ -2306,7 +2329,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora"}, "FNAME",
"path to LoRA adapter (use comma-separated values to load multiple adapters)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
for (const auto & item : parse_csv_row(value)) {
params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
}
}
@ -2317,7 +2340,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n"
"note: use comma-separated values",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
for (const auto & item : parse_csv_row(value)) {
auto parts = string_split<std::string>(item, ':');
if (parts.size() != 2) {
throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
@ -2331,7 +2354,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--control-vector"}, "FNAME",
"add a control vector\nnote: use comma-separated values to add multiple control vectors",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
for (const auto & item : parse_csv_row(value)) {
params.control_vectors.push_back({ 1.0f, item, });
}
}
@ -2341,7 +2364,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"add a control vector with user defined scaling SCALE\n"
"note: use comma-separated values (format: FNAME:SCALE,...)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
for (const auto & item : parse_csv_row(value)) {
auto parts = string_split<std::string>(item, ':');
if (parts.size() != 2) {
throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE");
@ -2439,7 +2462,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--context-file"}, "FNAME",
"file to load context from (use comma-separated values to specify multiple files)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
for (const auto & item : parse_csv_row(value)) {
std::ifstream file(item, std::ios::binary);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
@ -2675,9 +2698,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
add_opt(common_arg(
{"--api-key"}, "KEY",
"API key to use for authentication (default: none)",
"API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)",
[](common_params & params, const std::string & value) {
params.api_keys.push_back(value);
for (const auto & key : parse_csv_row(value)) {
if (!key.empty()) {
params.api_keys.push_back(key);
}
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
add_opt(common_arg(
@ -2691,7 +2718,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::string key;
while (std::getline(key_file, key)) {
if (!key.empty()) {
params.api_keys.push_back(key);
params.api_keys.push_back(key);
}
}
key_file.close();
@ -2713,7 +2740,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE"));
add_opt(common_arg(
{"--chat-template-kwargs"}, "STRING",
string_format("sets additional params for the json template parser"),
"sets additional params for the json template parser, must be a valid json object string, e.g. '{\"key1\":\"value1\",\"key2\":\"value2\"}'",
[](common_params & params, const std::string & value) {
auto parsed = json::parse(value);
for (const auto & item : parsed.items()) {

View File

@ -1963,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
acl_tensor_ptr acl_weight_tensor;
// Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else {

View File

@ -103,7 +103,7 @@ const ggml_cann_device_info & ggml_cann_info();
void ggml_cann_set_device(int32_t device);
int32_t ggml_cann_get_device();
std::optional<std::string> get_env(const std::string & name);
std::optional<std::string> get_env_as_lowercase(const std::string & name);
bool parse_bool(const std::string & value);
int parse_integer(const std::string & value);

View File

@ -105,10 +105,10 @@ int32_t ggml_cann_get_device() {
}
/**
* @brief Get the value of the specified environment variable (name).
* @brief Get the value of the specified environment variable (name) as lowercase.
* if not empty, return a std::string object
*/
std::optional<std::string> get_env(const std::string & name) {
std::optional<std::string> get_env_as_lowercase(const std::string & name) {
const char * val = std::getenv(name.c_str());
if (!val) {
return std::nullopt;
@ -259,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
* @param device The device ID to associate with this buffer pool.
*/
explicit ggml_cann_pool_buf_prio(int device) : device(device) {
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
}
/**
@ -452,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
* @param device The device ID to associate with this buffer pool.
*/
explicit ggml_cann_pool_buf(int device) : device(device) {
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
}
/**
@ -764,7 +764,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
* @return A unique pointer to the created CANN pool.
*/
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or("");
if (mem_pool_type == "prio") {
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
@ -1217,7 +1217,7 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
// Why aclrtSynchronizeDevice?
// Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
@ -1442,7 +1442,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
int64_t ne0 = tensor->ne[0];
// Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
// last line must bigger than 32, because every single op deal at
// least 32 bytes.
@ -2136,7 +2136,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or(""));
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
if (!use_cann_graph || cann_graph_capture_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
@ -2201,7 +2201,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
if (!prefill_use_graph) {
// Do not use acl_graph for prefill.
for (int i = 0; i < cgraph->n_nodes; i++) {

View File

@ -1036,7 +1036,7 @@ struct ggml_tensor_extra_gpu {
#define USE_CUDA_GRAPH
#endif
struct ggml_graph_node_properties {
struct ggml_cuda_graph_node_properties {
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
@ -1061,11 +1061,25 @@ struct ggml_cuda_graph {
std::vector<cudaGraphNode_t> nodes;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0;
bool cuda_graphs_enabled = false;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
std::vector<ggml_graph_node_properties> extraneous_srcs_properties;
std::vector<ggml_cuda_graph_node_properties> props;
void record_update(bool use_graph, bool update_required) {
if (use_graph && update_required) {
number_consecutive_updates++;
} else {
number_consecutive_updates = 0;
}
if (number_consecutive_updates >= 4) {
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
disable_due_to_too_many_updates = true;
}
}
bool is_enabled() const {
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
}
#endif
};

View File

@ -2853,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
}
#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
bool use_cuda_graph) {
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
bool use_cuda_graph = true;
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
@ -2915,41 +2915,41 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
return use_cuda_graph;
}
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
graph_node_properties->node_address = node->data;
graph_node_properties->node_op = node->op;
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
props->node_address = node->data;
props->node_op = node->op;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
graph_node_properties->ne[i] = node->ne[i];
graph_node_properties->nb[i] = node->nb[i];
props->ne[i] = node->ne[i];
props->nb[i] = node->nb[i];
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
}
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address &&
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
if (node->data != props->node_address &&
node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != graph_node_properties->node_op) {
if (node->op != props->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != graph_node_properties->ne[i]) {
if (node->ne[i] != props->ne[i]) {
return false;
}
if (node->nb[i] != graph_node_properties->nb[i]) {
if (node->nb[i] != props->nb[i]) {
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] &&
node->src[i]->data != graph_node_properties->src_address[i] &&
node->src[i]->data != props->src_address[i] &&
node->op != GGML_OP_VIEW
) {
return false;
@ -2957,56 +2957,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
}
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
}
return true;
}
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
bool cuda_graph_update_required = false;
bool res = false;
if (cuda_ctx->cuda_graph->instance == nullptr) {
cuda_graph_update_required = true;
res = true;
}
// Check if the graph size has changed
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
cuda_graph_update_required = true;
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes + cgraph->n_leafs);
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
res = true;
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
bool props_match = true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
}
if (!has_matching_properties) {
cuda_graph_update_required = true;
if (!props_match) {
res = true;
}
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
}
for (int i = 0; i < cgraph->n_leafs; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]);
bool props_match= true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
}
if (!has_matching_properties) {
cuda_graph_update_required = true;
if (!props_match) {
res = true;
}
set_ggml_graph_node_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]);
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
}
return cuda_graph_update_required;
return res;
}
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
@ -3237,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
bool graph_evaluated_or_captured = false;
// flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
bool is_concurrent_event_active = false;
@ -3710,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
}
if (cuda_graph_update_required) { // Update graph executable
update_cuda_graph_executable(cuda_ctx);
ggml_cuda_graph_update_executable(cuda_ctx);
}
// Launch graph
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
@ -3720,43 +3720,25 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
}
static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
#ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
// Objects required for CUDA Graph
if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}
bool use_cuda_graph = true;
if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
}
}
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
// or previous graph capture failure.
// Also disable for multi-gpu for now. TO DO investigate
if (disable_cuda_graphs_due_to_env
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
use_cuda_graph = false;
}
cuda_ctx->cuda_graph->cuda_graphs_enabled = use_cuda_graph;
return cuda_ctx->cuda_graph->is_enabled();
#else
bool use_cuda_graph = false;
return false;
#endif // USE_CUDA_GRAPH
return use_cuda_graph;
}
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
@ -3767,30 +3749,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
// graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
// we call it here instead.
#ifdef USE_CUDA_GRAPH
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
if (cuda_ctx->cuda_graph->is_enabled()) {
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
cuda_ctx->cuda_graph->number_consecutive_updates++;
} else {
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
}
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
}
cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
}
#endif // USE_CUDA_GRAPH
@ -3804,9 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}
bool graph_evaluated_or_captured = false;
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
return GGML_STATUS_SUCCESS;
}
@ -3839,7 +3803,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
const bool use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
static bool enable_graph_optimization = [] {
const char * env = getenv("GGML_CUDA_GRAPH_OPT");

View File

@ -34,13 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// CUDA_GRAPHS_DISABLED
((ncols > 65536) &&
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
ctx.cuda_graph->is_enabled())) ||
// CUDA_GRAPHS ENABLED
((ncols > 32768) &&
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
ctx.cuda_graph->is_enabled()))) {
#else
(ncols > 65536)) {
#endif // USE_CUDA_GRAPH

View File

@ -333,6 +333,28 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
}
if (amd_wmma_available(cc)) {
// RDNA 4 is consistently worse on rocblas
// https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
// High expert counts almost always better on MMQ
// due to a large amount of graph splits
// https://github.com/ggml-org/llama.cpp/pull/18202
if (n_experts >= 64) {
return true;
}
switch (type) {
// These quants are really bad on MMQ
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q6_K:
// These quants are usually worse but not always
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
return ne11 <= 128;
default:
return true;
}
}
return true;
}

View File

@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1)
#endif // __clang__
// assumes as many threads as d_state
template <int splitH, int d_state>
template <int c_factor, int d_state>
__global__ void __launch_bounds__(d_state, 1)
ssm_scan_f32_group(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1)
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
const int head_idx = (blockIdx.x * splitH) / d_head;
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int warp = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int warp_idx = blockIdx.x * c_factor + warp;
const int head_idx = warp_idx / d_head;
const int head_off = (warp_idx % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
// TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
// strides across n_seq_tokens
const int stride_x = src1_nb2 / sizeof(float);
@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1)
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = n_head * d_head;
float state[splitH];
// for the parallel accumulation
__shared__ float stateC[splitH * d_state];
float state[c_factor];
float state_sum = 0.0f;
#pragma unroll
for (int j = 0; j < splitH; j++) {
state[j] = s0_block[j * d_state + threadIdx.x];
for (int j = 0; j < c_factor; j++) {
state[j] = s0_warp[WARP_SIZE * j + lane];
}
for (int64_t i = 0; i < n_tok; i++) {
// TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
// TODO: only calculate B and C once per head group
// NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
float dt_soft_plus = dt_block[i * stride_dt];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
const float dA = expf(dt_soft_plus * A_block[0]);
const float B = B_block[i * stride_B + threadIdx.x];
const float C = C_block[i * stride_C + threadIdx.x];
// NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
// Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);
// across d_head
state_sum = 0.0f;
const float dA = expf(dt_soft_plus * A_warp[0]);
const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
#pragma unroll
for (int j = 0; j < splitH; j++) {
const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
state[j] = (state[j] * dA) + (B * x_dt);
stateC[j * d_state + threadIdx.x] = state[j] * C;
for (int j = 0; j < c_factor; j++) {
const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
state[j] = (state[j] * dA) + (B_val * x_dt);
state_sum += state[j] * C_val;
}
__syncthreads();
// parallel accumulation for output
state_sum = warp_reduce_sum(state_sum);
// parallel accumulation for stateC
// TODO: simplify
{
static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
// reduce until w matches the warp size
// TODO: does this work even when the physical warp size is 64?
#pragma unroll
for (int w = d_state; w > WARP_SIZE; w >>= 1) {
// (assuming there are d_state threads)
#pragma unroll
for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
// TODO: check for bank conflicts
const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
stateC[k] += stateC[k + (w >> 1)];
}
__syncthreads();
}
static_assert(splitH >= d_state / WARP_SIZE);
#pragma unroll
for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
y = warp_reduce_sum(y);
// store the above accumulations
if (threadIdx.x % WARP_SIZE == 0) {
const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
y_block[i * stride_y + k] = y;
}
}
if (lane == 0) {
y_warp[i * stride_y] = state_sum;
}
}
// write back the state
#pragma unroll
for (int j = 0; j < splitH; j++) {
s_block[j * d_state + threadIdx.x] = state[j];
for (int j = 0; j < c_factor; j++) {
s_warp[WARP_SIZE * j + lane] = state[j];
}
}
@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
const int splitH = 16;
GGML_ASSERT(head_dim % splitH == 0);
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
constexpr int threads = 128;
constexpr int num_warps = threads/WARP_SIZE;
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
} else if (d_state == 256) { // Falcon-H1
const int threads = 256;
// NOTE: can be any power of two between 8 and 64
const int splitH = 16;
GGML_ASSERT(head_dim % splitH == 0);
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
constexpr int threads = 256;
constexpr int num_warps = threads/WARP_SIZE;
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
}
} else {
// Mamba-1
constexpr int threads = 128;
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);

View File

@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_
return true;
}
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * src2 = op->src[2];
const struct ggml_tensor * src3 = op->src[3];
const struct ggml_tensor * src4 = op->src[4];
const struct ggml_tensor * dst = op;
// Check for F16 support only as requested
if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
return false;
}
if (src3 && src3->type != GGML_TYPE_F16) { // mask
return false;
}
if (src4 && src4->type != GGML_TYPE_F32) { // sinks
return false;
}
// For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
// but the op implementation writes to F16 or F32.
// Let's assume dst can be F32 or F16.
if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
return false;
}
return opt_experimental;
}
static bool hex_supported_src0_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
// TODO: add support for non-cont tensors
if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
return false;
}
@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
return false; // typically the lm-head which would be too large for VTCM
}
// if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false;
if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
return false;
}
@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
}
break;
case GGML_TYPE_F16:
if (!opt_experimental) {
return false;
}
break;
default:
return false;
}
// TODO: add support for non-cont tensors
if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
return false;
}
return true;
}
@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
return true;
}
static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0]; // values
const struct ggml_tensor * src1 = op->src[1]; // indices
const struct ggml_tensor * dst = op;
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
return false;
}
if (dst->type != GGML_TYPE_F16) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0]; // values
const struct ggml_tensor * src1 = op->src[1]; // indices
const struct ggml_tensor * dst = op;
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
return false;
}
if (dst->type != GGML_TYPE_F32) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const int32_t * op_params = &op->op_params[0];
@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t
d->offset = (uint8_t *) t->data - buf->base;
d->size = ggml_nbytes(t);
if (!d->size) {
// Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
d->size = 64;
}
switch (type) {
case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
// Flush CPU
@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
return n_bufs;
}
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_GET_ROWS;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
switch (t->op) {
@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer *
return n_bufs;
}
static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_SET_ROWS;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
supported = true;
break;
case GGML_OP_SCALE:
req->op = HTP_OP_SCALE;
supported = true;
break;
case GGML_OP_UNARY:
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
req->op = HTP_OP_UNARY_SILU;
@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs
return n_bufs;
}
static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_FLASH_ATTN_EXT;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
return sess->name.c_str();
@ -2417,6 +2522,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
break;
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
case GGML_OP_UNARY:
@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
break;
case GGML_OP_SET_ROWS:
ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
break;
case GGML_OP_GET_ROWS:
ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
break;
default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
break;
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
supp = ggml_hexagon_supported_unary(sess, op);
break;
@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_rope(sess, op);
break;
case GGML_OP_FLASH_ATTN_EXT:
supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
break;
case GGML_OP_SET_ROWS:
supp = ggml_hexagon_supported_set_rows(sess, op);
break;
case GGML_OP_GET_ROWS:
supp = ggml_hexagon_supported_get_rows(sess, op);
break;
default:
break;
}

View File

@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED
softmax-ops.c
act-ops.c
rope-ops.c
flash-attn-ops.c
set-rows-ops.c
get-rows-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE

View File

@ -0,0 +1,566 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
hvx_vec_store_u(r, 4, rsum);
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
hvx_vec_store_u(r, 4, rsum);
}
// MAD: y (F32) += x (F16) * v (float)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S = hvx_vec_splat_fp16(s);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
// Multiply x * s -> pair of F32 vectors
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
}
if (nloe) {
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
HVX_Vector xs = Q6_V_lo_W(xs_p);
i = 2 * i; // index for ptr_y
if (nloe >= 32) {
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
}
if (nloe) {
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
}
}
}
#define FLASH_ATTN_BLOCK_SIZE 128
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
struct htp_tensor * dst = &octx->dst;
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];
const uint32_t neq2 = q->ne[2];
const uint32_t neq3 = q->ne[3];
const uint32_t nek0 = k->ne[0];
const uint32_t nek1 = k->ne[1];
const uint32_t nek2 = k->ne[2];
const uint32_t nek3 = k->ne[3];
const uint32_t nev0 = v->ne[0];
const uint32_t nev1 = v->ne[1];
const uint32_t nev2 = v->ne[2];
const uint32_t nev3 = v->ne[3];
const uint32_t nbq1 = q->nb[1];
const uint32_t nbq2 = q->nb[2];
const uint32_t nbq3 = q->nb[3];
const uint32_t nbk1 = k->nb[1];
const uint32_t nbk2 = k->nb[2];
const uint32_t nbk3 = k->nb[3];
const uint32_t nbv1 = v->nb[1];
const uint32_t nbv2 = v->nb[2];
const uint32_t nbv3 = v->nb[3];
const uint32_t ne1 = dst->ne[1];
const uint32_t ne2 = dst->ne[2];
const uint32_t ne3 = dst->ne[3];
const uint32_t nb1 = dst->nb[1];
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// total rows in q
const uint32_t nr = neq1*neq2*neq3;
const uint32_t dr = (nr + nth - 1) / nth;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = MIN(ir0 + dr, nr);
if (ir0 >= ir1) return;
dma_queue * dma = octx->ctx->dma[ith];
const uint32_t DK = nek0;
const uint32_t DV = nev0;
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
const size_t size_k_row = DK * sizeof(__fp16);
const size_t size_v_row = DV * sizeof(__fp16);
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
// Fetch Q row
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
// Clear accumulator
float * VKQ32 = (float *) spad_a;
memset(VKQ32, 0, DV * sizeof(float));
const __fp16 * mp_base = NULL;
if (mask) {
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
}
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
// Prefetch first two blocks
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// Wait for DMA
uint8_t * k_base = dma_queue_pop(dma).dst; // K
uint8_t * v_base = dma_queue_pop(dma).dst; // V
__fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
// Inner loop processing the block from VTCM
uint32_t ic = 0;
// Process in blocks of 32 (VLEN_FP32)
for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (q->type == HTP_TYPE_F32) {
hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
}
}
HVX_Vector scores = *(HVX_Vector *) scores_arr;
// 2. Softcap
if (logit_softcap != 0.0f) {
scores = hvx_vec_tanh_fp32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
scores = Q6_Vsf_equals_Vqf32(scores);
}
// 3. Mask
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
scores = Q6_Vsf_equals_Vqf32(scores);
}
// 4. Online Softmax Update
HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
float m_block = hvx_vec_get_fp32(v_max);
float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;
float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
S = S * ms;
HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
float p_sum = hvx_vec_get_fp32(p_sum_vec);
S += p_sum;
// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
}
}
// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (q->type == HTP_TYPE_F32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
if (logit_softcap != 0.0f) {
s_val = logit_softcap * tanhf(s_val);
}
if (mask) {
const float m_val = m_base[ic];
s_val += slope * m_val;
}
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s_val > M) {
M = s_val;
ms = expf(Mold - M);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s_val - M);
}
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
S = S * ms + vs;
}
// Issue DMA for next+1 block (if exists)
if (ib + 2 < n_blocks) {
const uint32_t next_ib = ib + 2;
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
}
}
}
// sinks
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s - M);
}
S = S * ms + vs;
}
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
// Store result
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// dst is permuted
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
if (dst->type == HTP_TYPE_F32) {
hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
} else if (dst->type == HTP_TYPE_F16) {
hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
}
}
}
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
flash_attn_ext_f16_thread(octx, i, n);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
struct htp_tensor * dst = &octx->dst;
// Check support
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
k->type != HTP_TYPE_F16 ||
v->type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
size_t size_q_block = size_q_row_padded * 1; // single row for now
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = size_k_block * 2;
octx->src2_spad.size_per_thread = size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
if (octx->ctx->vtcm_size < total_spad) {
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
}
return HTP_STATUS_OK;
}

View File

@ -0,0 +1,112 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
#define get_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
const uint32_t ne02 = octx->src0.ne[2]; \
const uint32_t ne03 = octx->src0.ne[3]; \
\
const uint32_t ne10 = octx->src1.ne[0]; \
const uint32_t ne11 = octx->src1.ne[1]; \
const uint32_t ne12 = octx->src1.ne[2]; \
\
const uint32_t nb01 = octx->src0.nb[1]; \
const uint32_t nb02 = octx->src0.nb[2]; \
const uint32_t nb03 = octx->src0.nb[3]; \
\
const uint32_t nb10 = octx->src1.nb[0]; \
const uint32_t nb11 = octx->src1.nb[1]; \
const uint32_t nb12 = octx->src1.nb[2]; \
\
const uint32_t nb1 = octx->dst.nb[1]; \
const uint32_t nb2 = octx->dst.nb[2]; \
const uint32_t nb3 = octx->dst.nb[3]; \
\
const uint32_t nr = ne10 * ne11 * ne12;
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
get_rows_preamble;
// parallelize by src1 elements (which correspond to dst rows)
const uint32_t dr = octx->src1_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
const uint32_t rem = i - i12 * ne11 * ne10;
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
const uint32_t i10 = rem - i11 * ne10;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i01 >= ne01) {
// invalid index, skip for now to avoid crash
continue;
}
const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
return HTP_STATUS_OK;
}
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_get_rows(struct htp_ops_context * octx) {
get_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->dst.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
return HTP_STATUS_OK;
}

View File

@ -11,11 +11,6 @@
#define HTP_MAX_NTHREADS 10
// FIXME: move these into matmul-ops
#define HTP_SPAD_SRC0_NROWS 16
#define HTP_SPAD_SRC1_NROWS 16
#define HTP_SPAD_DST_NROWS 2
// Main context for htp DSP backend
struct htp_context {
dspqueue_t queue;

View File

@ -36,6 +36,8 @@ enum htp_data_type {
HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8,
HTP_TYPE_I32 = 26,
HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_COUNT
};
@ -57,6 +59,10 @@ enum htp_op {
HTP_OP_SOFTMAX = 11,
HTP_OP_ADD_ID = 12,
HTP_OP_ROPE = 13,
HTP_OP_FLASH_ATTN_EXT = 14,
HTP_OP_SET_ROWS = 15,
HTP_OP_SCALE = 16,
HTP_OP_GET_ROWS = 17,
INVALID
};
@ -137,6 +143,8 @@ struct htp_general_req {
struct htp_tensor src0; // Input0 tensor
struct htp_tensor src1; // Input1 tensor
struct htp_tensor src2; // Input2 tensor
struct htp_tensor src3; // Input3 tensor
struct htp_tensor src4; // Input4 tensor
struct htp_tensor dst; // Output tensor
// should be multiple of 64 bytes (cacheline)
@ -152,6 +160,6 @@ struct htp_general_rsp {
};
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
#define HTP_MAX_PACKET_BUFFERS 4
#define HTP_MAX_PACKET_BUFFERS 8
#endif /* HTP_MSG_H */

View File

@ -13,6 +13,7 @@
struct htp_spad {
uint8_t * data;
size_t stride;
size_t size;
size_t size_per_thread;
};
@ -26,11 +27,14 @@ struct htp_ops_context {
struct htp_tensor src0;
struct htp_tensor src1;
struct htp_tensor src2;
struct htp_tensor src3;
struct htp_tensor src4;
struct htp_tensor dst;
struct htp_spad src0_spad;
struct htp_spad src1_spad;
struct htp_spad src2_spad;
struct htp_spad src3_spad;
struct htp_spad dst_spad;
worker_pool_context_t * wpool; // worker pool
@ -49,6 +53,27 @@ struct htp_ops_context {
struct fastdiv_values src1_div3; // fastdiv values for ne3
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values src3_div1; // fastdiv values for ne1
struct fastdiv_values src3_div2; // fastdiv values for ne2
struct fastdiv_values src3_div3; // fastdiv values for ne3
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
uint32_t flags;
};
@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
int op_rope(struct htp_ops_context * octx);
int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */

View File

@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
}
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
int unaligned_addr = 0;
int unaligned_loop = 0;
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
unaligned_addr = 1;
}
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
unaligned_loop = 1;
FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
}
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
if (0 == unaligned_loop) {
HVX_Vector * vec_in1 = (HVX_Vector *) src;
HVX_Vector * vec_out = (HVX_Vector *) dst;
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
}
} else {
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
}
}
if (left_over > 0) {
const float * srcf = (const float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
HVX_Vector in = *(HVX_UVector *) srcf;
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
}
}
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
}
}

View File

@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
}
#endif
static inline HVX_Vector hvx_vec_splat_fp32(float i) {
static inline HVX_Vector hvx_vec_splat_fp32(float v) {
union {
float f;
int32_t i;
} fp32 = { .f = i };
float f;
uint32_t i;
} fp32 = { .f = v };
return Q6_V_vsplat_R(fp32.i);
}
static inline HVX_Vector hvx_vec_splat_fp16(float v) {
union {
__fp16 f;
uint16_t i;
} fp16 = { .f = v };
return Q6_Vh_vsplat_R(fp16.i);
}
static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
// Rotate as needed.
v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest
}
}
// copy n fp32 elements : source is unaligned, destination unaligned
static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
HVX_UVector * restrict vdst = (HVX_UVector *) dst;
HVX_UVector * restrict vsrc = (HVX_UVector *) src;
assert((unsigned long) dst % 128 == 0);
uint32_t nvec = n / 32;
uint32_t nloe = n % 32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; i++) {
HVX_Vector v = vsrc[i];
vdst[i] = v;
}
if (nloe) {
HVX_Vector v = vsrc[i];
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
}
}
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
const HVX_Vector zero = Q6_V_vsplat_R(0);
uint32_t nvec = n / 64;
uint32_t nloe = n % 64;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
}
}
// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32
const HVX_Vector zero = Q6_V_vsplat_R(0);
uint32_t nvec = n / 64;
uint32_t nloe = n % 64;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
}
}
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16
HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
const HVX_Vector zero = Q6_V_vsplat_R(0);
uint32_t nvec = n / 64;
uint32_t nloe = n % 64;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
}
}
// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
return right_off <= chunk_size;
}
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
HVX_VectorAlias u = { .v = v };
@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
}
static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
#if __HTP_ARCH__ > 75
#if __HVX_ARCH__ > 75
return Q6_Vsf_vfneg_Vsf(v);
#else
// neg by setting the fp32 sign bit
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
return Q6_V_vxor_VV(v, mask);
#endif // __HTP_ARCH__ > 75
#endif // __HVX_ARCH__ > 75
}
// ====================================================
@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
}
static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) {
// tanh(x) = 2 * sigmoid(2x) - 1
HVX_Vector two = hvx_vec_splat_fp32(2.0f);
HVX_Vector one = hvx_vec_splat_fp32(1.0f);
HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
static const float kMinExp = -87.f; // 0
static const float kMaxExp = 87.f; // 1
HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
return Q6_Vsf_equals_Vqf32(res);
}
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
int step_of_1 = num_elems >> 5;
int remaining = num_elems - step_of_1 * VLEN_FP32;
@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr
}
}
static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
int nvec = n / VLEN_FP32;
int nloe = n % VLEN_FP32;
HVX_Vector vs = hvx_vec_splat_fp32(scale);
HVX_Vector * vsrc = (HVX_Vector *) src;
HVX_Vector * vdst = (HVX_Vector *) dst;
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
vdst[i] = Q6_Vsf_equals_Vqf32(v);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
}
}
static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
int nvec = n / VLEN_FP32;
int nloe = n % VLEN_FP32;
HVX_Vector vs = hvx_vec_splat_fp32(scale);
HVX_UVector * vsrc = (HVX_UVector *) src;
HVX_UVector * vdst = (HVX_UVector *) dst;
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
vdst[i] = Q6_Vsf_equals_Vqf32(v);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
}
}
static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
hvx_scale_f32_aa(dst, src, n, scale);
} else {
hvx_scale_f32_uu(dst, src, n, scale);
}
}
static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
int nvec = n / VLEN_FP32;
int nloe = n % VLEN_FP32;
HVX_Vector vs = hvx_vec_splat_fp32(scale);
HVX_Vector vo = hvx_vec_splat_fp32(offset);
HVX_Vector * vsrc = (HVX_Vector *) src;
HVX_Vector * vdst = (HVX_Vector *) dst;
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
vdst[i] = Q6_Vsf_equals_Vqf32(v);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
}
}
static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
int nvec = n / VLEN_FP32;
int nloe = n % VLEN_FP32;
HVX_Vector vs = hvx_vec_splat_fp32(scale);
HVX_Vector vo = hvx_vec_splat_fp32(offset);
HVX_UVector * vsrc = (HVX_UVector *) src;
HVX_UVector * vdst = (HVX_UVector *) dst;
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
vdst[i] = Q6_Vsf_equals_Vqf32(v);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
}
}
static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
} else {
hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
}
}
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
void hvx_mul_f32(const uint8_t * restrict src0,
@ -1090,7 +1338,6 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0,
uint8_t * restrict dst,
const int num_elems);
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);

View File

@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.dst.data = (uint32_t) bufs[2].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_get_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_matmul_id_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context * ctx,
uint32_t n_bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
int write_idx = (n_bufs == 4) ? 3 : 2;
int write_idx = n_bufs - 1;
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[write_idx].fd;
@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.dst.data = (uint32_t) bufs[2].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_set_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_flash_attn_ext_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
uint32_t n_bufs) {
// Setup Op context
struct htp_ops_context octx;
memset(&octx, 0, sizeof(octx));
octx.ctx = ctx;
octx.n_threads = ctx->n_threads;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.src2 = req->src2;
octx.src3 = req->src3;
octx.src4 = req->src4;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.src2.data = (uint32_t) bufs[2].ptr;
int last_buf = 3;
if (octx.src3.ne[0]) {
octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
}
if (octx.src4.ne[0]) {
octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
}
octx.dst.data = (uint32_t) bufs[last_buf].ptr;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_flash_attn_ext(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
struct dspqueue_buffer rsp_buf = bufs[last_buf];
rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
}
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_context * ctx = (struct htp_context *) context;
@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
break;
case HTP_OP_RMS_NORM:
case HTP_OP_SCALE:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_rope_req(ctx, &req, bufs, n_bufs);
break;
case HTP_OP_FLASH_ATTN_EXT:
if (!(n_bufs >= 4 && n_bufs <= 6)) {
FARF(ERROR, "Bad flash-attn-ext-req buffer list");
continue;
}
proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
break;
case HTP_OP_SET_ROWS:
if (n_bufs != 3) {
FARF(ERROR, "Bad set-rows-req buffer list");
continue;
}
proc_set_rows_req(ctx, &req, bufs);
break;
case HTP_OP_GET_ROWS:
if (n_bufs != 3) {
FARF(ERROR, "Bad get-rows-req buffer list");
continue;
}
proc_get_rows_req(ctx, &req, bufs);
break;
default:
FARF(ERROR, "Unknown Op %u", req.op);
break;

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,168 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
#define set_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
const uint32_t ne02 = octx->src0.ne[2]; \
const uint32_t ne03 = octx->src0.ne[3]; \
\
const uint32_t ne10 = octx->src1.ne[0]; \
const uint32_t ne11 = octx->src1.ne[1]; \
const uint32_t ne12 = octx->src1.ne[2]; \
\
const uint32_t nb01 = octx->src0.nb[1]; \
const uint32_t nb02 = octx->src0.nb[2]; \
const uint32_t nb03 = octx->src0.nb[3]; \
\
const uint32_t nb10 = octx->src1.nb[0]; \
const uint32_t nb11 = octx->src1.nb[1]; \
const uint32_t nb12 = octx->src1.nb[2]; \
\
const uint32_t nb1 = octx->dst.nb[1]; \
const uint32_t nb2 = octx->dst.nb[2]; \
const uint32_t nb3 = octx->dst.nb[3]; \
\
const uint32_t ne1 = octx->dst.ne[1]; \
\
const uint32_t nr = ne01;
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i1 >= ne1) {
// ignore invalid indices
continue;
}
const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
// copy row
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
}
}
return HTP_STATUS_OK;
}
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i1 >= ne1) {
// ignore invalid indices
continue;
}
const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
}
}
}
return HTP_STATUS_OK;
}
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
}
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_set_rows(struct htp_ops_context * octx) {
set_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
switch(octx->dst.type) {
case HTP_TYPE_F32:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
break;
case HTP_TYPE_F16:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
break;
default:
return HTP_STATUS_NO_SUPPORT;
}
return HTP_STATUS_OK;
}

View File

@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
(const uint8_t *) mp_f32, slope);
} else {
hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
if (mp_f32) {
if (softmax_ctx->use_f16) {
for (int i = 0; i < ne00; ++i) {
@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
sum = sum > 0.0 ? (1.0 / sum) : 1;
hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
}
}
}

View File

@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
}
}
static void scale_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
float scale = 0.f;
float bias = 0.f;
memcpy(&scale, &op_params[0], sizeof(float));
memcpy(&bias, &op_params[1], sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
}
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
}
}
static void rms_norm_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src,
const float mean = sum / row_elems;
const float scale = 1.0f / sqrtf(mean + epsilon);
hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
}
}
}
@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
case HTP_OP_RMS_NORM:
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
default:
break;
@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
unary_op_func = unary_job_dispatcher_f32;
op_type = "rmsnorm-f32";
break;
case HTP_OP_SCALE:
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);

View File

@ -550,6 +550,8 @@ struct vk_device_struct {
uint64_t max_memory_allocation_size;
uint64_t max_buffer_size;
uint64_t suballocation_block_size;
uint64_t min_imported_host_pointer_alignment;
bool external_memory_host {};
bool fp16;
bool bf16;
bool pipeline_robustness;
@ -2410,7 +2412,8 @@ static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDe
return indices;
}
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
void *import_ptr = nullptr) {
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
if (size > device->max_buffer_size) {
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
@ -2439,6 +2442,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
nullptr,
};
vk::ExternalMemoryBufferCreateInfo external_memory_bci;
if (import_ptr) {
external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
buffer_create_info.setPNext(&external_memory_bci);
}
buf->buffer = device->device.createBuffer(buffer_create_info);
vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
@ -2453,35 +2462,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
mem_flags_info.setPNext(&mem_priority_info);
}
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
const auto & req_flags = *it;
const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
if (memory_type_indices.empty()) {
continue;
if (import_ptr) {
vk::MemoryHostPointerPropertiesEXT host_pointer_props;
try {
host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
} catch (vk::SystemError& e) {
GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
device->device.destroyBuffer(buf->buffer);
return {};
}
buf->memory_property_flags = req_flags;
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
bool done = false;
uint32_t memory_type_idx;
vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
continue;
}
if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
continue;
}
for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
try {
buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
done = true;
vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
// check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
if ((memory_type.propertyFlags & property_flags) == property_flags) {
property_flags = memory_type.propertyFlags;
break;
} catch (const vk::SystemError& e) {
// loop and retry
// during last attempt throw the exception
if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
device->device.destroyBuffer(buf->buffer);
throw e;
}
}
}
if (memory_type_idx == 32) {
GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
device->device.destroyBuffer(buf->buffer);
return {};
}
if (done) {
break;
buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
try {
vk::ImportMemoryHostPointerInfoEXT import_info;
import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
import_info.pHostPointer = import_ptr;
import_info.setPNext(&mem_flags_info);
buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
} catch (const vk::SystemError& e) {
}
} else {
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
const auto & req_flags = *it;
const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
if (memory_type_indices.empty()) {
continue;
}
buf->memory_property_flags = req_flags;
bool done = false;
for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
try {
buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
done = true;
break;
} catch (const vk::SystemError& e) {
// loop and retry
// during last attempt throw the exception
if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
device->device.destroyBuffer(buf->buffer);
throw e;
}
}
}
if (done) {
break;
}
}
}
@ -2492,8 +2546,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
buf->ptr = nullptr;
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
if (import_ptr) {
buf->ptr = import_ptr;
} else {
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
}
}
device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
@ -4447,6 +4505,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
device->memory_priority = true;
} else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
device->external_memory_host = true;
}
}
@ -4461,6 +4521,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk::PhysicalDeviceVulkan12Properties vk12_props;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
@ -4500,11 +4561,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
}
if (device->external_memory_host) {
last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
last_struct = (VkBaseOutStructure *)&external_memory_host_props;
}
device->physical_device.getProperties2(&props2);
device->properties = props2.properties;
device->vendor_id = device->properties.vendorID;
device->driver_id = driver_props.driverID;
if (device->driver_id == vk::DriverId::eMoltenvk) {
// Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622
// is available in the Vulkan SDK.
device->external_memory_host = false;
}
// Implementing the async backend interfaces seems broken on older Intel HW,
// see https://github.com/ggml-org/llama.cpp/issues/17302.
device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
@ -4586,6 +4658,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
@ -4717,6 +4791,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_pipeline_executable_properties");
}
if (device->external_memory_host) {
device_extensions.push_back("VK_EXT_external_memory_host");
}
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
device->pipeline_executable_properties_support = pipeline_executable_properties_support;
@ -14773,6 +14851,51 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
}
static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
if (!device->external_memory_host) {
return {};
}
uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
return {};
}
if (size & (device->min_imported_host_pointer_alignment - 1)) {
return {};
}
const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
vk_buffer buf {};
try {
buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
} catch (vk::SystemError& e) {
GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
}
return buf;
}
static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
GGML_UNUSED(max_tensor_size);
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);
if (!buf) {
return {};
}
ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);
ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);
return ret;
}
static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
/* .get_name = */ ggml_backend_vk_device_get_name,
/* .get_description = */ ggml_backend_vk_device_get_description,
@ -14782,7 +14905,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
/* .init_backend = */ ggml_backend_vk_device_init,
/* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
/* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
/* .buffer_from_host_ptr = */ NULL,
/* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
/* .supports_op = */ ggml_backend_vk_device_supports_op,
/* .supports_buft = */ ggml_backend_vk_device_supports_buft,
/* .offload_op = */ ggml_backend_vk_device_offload_op,

View File

@ -16,8 +16,14 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf"
device="HTP0"
[ "$D" != "" ] && device="$D"
verbose=""
[ "$V" != "" ] && verbose="$V"
verbose=
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
experimental=
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
profile=
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v"
opmask=
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
@ -34,7 +40,7 @@ adb $adbserial shell " \
cd $basedir; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
$ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--batch-size 128 -ngl 99 $@ \
--batch-size 128 -ngl 99 $cli_opts $@ \
"

View File

@ -359,6 +359,11 @@ static void llama_params_fit_impl(
// for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE:
layer_fraction_t overflow_type = LAYER_FRACTION_MOE;
uint32_t n_full() const {
assert(n_layer >= n_part);
return n_layer - n_part;
}
};
const size_t ntbo = llama_max_tensor_buft_overrides();
@ -382,7 +387,7 @@ static void llama_params_fit_impl(
size_t itbo = 0;
for (size_t id = 0; id < nd; id++) {
il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part;
il0 += ngl_per_device[id].n_full();
for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) {
if (itbo + 1 >= ntbo) {
tensor_buft_overrides[itbo].pattern = nullptr;
@ -393,7 +398,7 @@ static void llama_params_fit_impl(
+ std::to_string(ntbo) + " is insufficient for model");
}
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type();
itbo++;
}
il0 += ngl_per_device[id].n_part;
@ -468,20 +473,14 @@ static void llama_params_fit_impl(
LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
}
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the partial layers of a device overflow to:
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the first partial layer of a device overflows to:
overflow_bufts.reserve(nd);
for (size_t id = 0; id < nd - 1; ++id) {
overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1]));
for (size_t id = 0; id < nd; id++) {
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
}
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
std::vector<ngl_t> ngl_per_device(nd);
std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts);
if (hp_nex > 0) {
for (size_t id = 0; id < nd; id++) {
ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE;
}
}
// optimize the number of layers per device using the method of false position:
// - ngl_per_device has 0 layers for each device, lower bound
@ -512,9 +511,6 @@ static void llama_params_fit_impl(
if (mem_high[id] > targets[id]) {
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
if (hp_nex > 0 && size_t(id) == nd - 1) {
delta--;
}
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
while (delta > 1) {
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
@ -524,7 +520,8 @@ static void llama_params_fit_impl(
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
ngl_per_device_test[id].n_layer += step_size;
if (hp_nex) {
ngl_per_device_test[id].n_part += step_size;
ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ?
step_size - 1 : step_size; // the first layer is the output layer which must always be full
}
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
@ -573,7 +570,7 @@ static void llama_params_fit_impl(
assert(id_dense_start < nd);
LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__);
for (size_t id = 0; id <= id_dense_start; id++) {
for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) {
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
for (size_t jd = id_dense_start; jd < nd; jd++) {
const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1;
@ -585,12 +582,8 @@ static void llama_params_fit_impl(
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
if (mem_high[id] > targets[id]) {
assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part);
assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part);
assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
>= ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
- (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
while (delta > 1) {
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
step_size = std::max(step_size, uint32_t(1));
@ -606,7 +599,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].n_layer += n_convert_jd;
n_converted_test += n_convert_jd;
if (ngl_per_device_test[id_dense_start_test].n_layer > 0) {
if (ngl_per_device_test[id_dense_start_test].n_part > 0) {
break;
}
}
@ -625,8 +618,8 @@ static void llama_params_fit_impl(
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n",
__func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high);
}
delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
- (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
}
} else {
ngl_per_device = ngl_per_device_high;
@ -644,14 +637,19 @@ static void llama_params_fit_impl(
ngl_per_device_test[id_dense_start_test].n_part--;
ngl_per_device_test[id].n_layer++;
ngl_per_device_test[id].n_part++;
if (ngl_per_device_test[id_dense_start_test].n_layer == 0) {
if (ngl_per_device_test[id_dense_start_test].n_part == 0) {
id_dense_start_test++;
}
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
std::vector<ggml_backend_buffer_type_t> overflow_bufts_test = overflow_bufts;
if (id < nd - 1) {
overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]);
}
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
overflow_bufts = overflow_bufts_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n",
@ -659,9 +657,10 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
overflow_bufts = overflow_bufts_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n",
@ -670,9 +669,10 @@ static void llama_params_fit_impl(
} else {
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
overflow_bufts = overflow_bufts_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n",
@ -687,6 +687,14 @@ static void llama_params_fit_impl(
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
}
// print info for devices that were not changed during the conversion from dense only to full layers:
for (size_t id = id_dense_start + 1; id < nd; id++) {
const int64_t projected_margin = dmds_full[id].free - mem[id];
LLAMA_LOG_INFO(
"%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n",
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
}
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
}

View File

@ -127,6 +127,15 @@ int main(void) {
assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE));
assert(params.speculative.n_max == 123);
// multi-value args (CSV)
argv = {"binary_name", "--lora", "file1.gguf,\"file2,2.gguf\",\"file3\"\"3\"\".gguf\",file4\".gguf"};
assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
assert(params.lora_adapters.size() == 4);
assert(params.lora_adapters[0].path == "file1.gguf");
assert(params.lora_adapters[1].path == "file2,2.gguf");
assert(params.lora_adapters[2].path == "file3\"3\".gguf");
assert(params.lora_adapters[3].path == "file4\".gguf");
// skip this part on windows, because setenv is not supported
#ifdef _WIN32
printf("test-arg-parser: skip on windows build\n");

View File

@ -9,207 +9,250 @@
#include <fstream>
#include <algorithm>
// most of the code here is copied from whisper.cpp
// some of the code here is copied from whisper.cpp
constexpr bool DEBUG = false;
struct mtmd_audio_mel_filters {
int32_t n_mel;
int32_t n_fft;
void mtmd_audio_cache::fill_sin_cos_table(int n) {
sin_vals.resize(n);
cos_vals.resize(n);
for (int i = 0; i < n; i++) {
double theta = (2 * M_PI * i) / n;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
}
std::vector<float> data;
};
void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
hann_window.resize(length);
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}
// note: this global cache is shared among all preprocessors
// if we want to use multiple preprocessors at the same time,
// we will need to enclose it in the preprocessor class in the future
static struct mtmd_audio_global_cache {
// precomputed sin/cos table for FFT
std::vector<float> sin_vals;
std::vector<float> cos_vals;
// hann window
std::vector<float> hann_window;
// mel filter bank
mtmd_audio_mel_filters filters;
void fill_sin_cos_table(int n) {
sin_vals.resize(n);
cos_vals.resize(n);
for (int i = 0; i < n; i++) {
double theta = (2 * M_PI * i) / n;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel,
int n_fft,
int sample_rate,
float fmin,
float fmax,
bool slaney_area_norm,
float scale) {
GGML_ASSERT(n_mel > 0 && n_fft > 1);
if (fmax <= 0.0f) {
fmax = 0.5f * sample_rate;
}
void fill_hann_window(int length, bool periodic) {
hann_window.resize(length);
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
// Slaney scale (matches librosa default)
const double min_log_hz = 1000.0;
const double lin_slope = 3 / 200.;
const double min_log_mel = min_log_hz * lin_slope;
const double log_step = log(6.4) / 27.0;
auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
};
auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
};
// infer N_fft from n_fft_bins
const double bin_hz_step = double(sample_rate) / double(n_fft);
// mel grid: n_mel + 2 edges
const double m_lo = hz_to_mel(fmin);
const double m_hi = hz_to_mel(fmax);
std::vector<double> mel_pts(n_mel + 2);
for (int i = 0; i < n_mel + 2; ++i) {
mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
}
// Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
// n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
void fill_mel_filterbank_matrix(
int n_mel,
int n_fft,
int sample_rate, // e.g. 16000
float fmin = 0.0f, // e.g. 0.0
float fmax = -1.0f, // e.g. sr/2; pass -1 for auto
bool slaney_area_norm = true,
float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code
) {
GGML_ASSERT(n_mel > 0 && n_fft > 1);
if (fmax <= 0.0f) {
fmax = 0.5f * sample_rate;
}
// convert to Hz
std::vector<double> hz_pts(n_mel + 2);
for (int i = 0; i < n_mel + 2; ++i) {
hz_pts[i] = mel_to_hz(mel_pts[i]);
}
// Slaney scale (matches librosa default)
const double min_log_hz = 1000.0;
const double lin_slope = 3 / 200.;
const double min_log_mel = min_log_hz * lin_slope;
const double log_step = log(6.4) / 27.0;
auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
};
auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
};
const int n_fft_bins = n_fft / 2 + 1;
// infer N_fft from n_fft_bins
const double bin_hz_step = double(sample_rate) / double(n_fft);
// filterbank
std::vector<float> out(n_mel * n_fft_bins, 0);
for (int m = 0; m < n_mel; ++m) {
const double f_left = hz_pts[m];
const double f_center = hz_pts[m + 1];
const double f_right = hz_pts[m + 2];
// mel grid: n_mel + 2 edges
const double m_lo = hz_to_mel(fmin);
const double m_hi = hz_to_mel(fmax);
std::vector<double> mel_pts(n_mel + 2);
for (int i = 0; i < n_mel + 2; ++i) {
mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
}
const double denom_l = std::max(1e-30, f_center - f_left);
const double denom_r = std::max(1e-30, f_right - f_center);
const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
// convert to Hz
std::vector<double> hz_pts(n_mel + 2);
for (int i = 0; i < n_mel + 2; ++i) {
hz_pts[i] = mel_to_hz(mel_pts[i]);
}
const int n_fft_bins = n_fft / 2 + 1;
// filterbank
std::vector<float> out(n_mel * n_fft_bins, 0);
for (int m = 0; m < n_mel; ++m) {
const double f_left = hz_pts[m];
const double f_center = hz_pts[m + 1];
const double f_right = hz_pts[m + 2];
const double denom_l = std::max(1e-30, f_center - f_left);
const double denom_r = std::max(1e-30, f_right - f_center);
const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
for (int k = 0; k < n_fft_bins; ++k) {
const double f = k * bin_hz_step;
double w = 0.0;
if (f >= f_left && f <= f_center) {
w = (f - f_left) / denom_l;
} else if (f > f_center && f <= f_right) {
w = (f_right - f) / denom_r;
}
out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
for (int k = 0; k < n_fft_bins; ++k) {
const double f = k * bin_hz_step;
double w = 0.0;
if (f >= f_left && f <= f_center) {
w = (f - f_left) / denom_l;
} else if (f > f_center && f <= f_right) {
w = (f_right - f) / denom_r;
}
out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
}
}
filters.n_mel = n_mel;
filters.n_fft = n_fft;
filters.data = std::move(out);
filters.n_mel = n_mel;
filters.n_fft = n_fft;
filters.data = std::move(out);
if (DEBUG) { // debug
for (size_t i = 0; i < filters.data.size(); ++i) {
if (filters.data[i] != 0.0f) {
printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
}
if (DEBUG) { // debug
for (size_t i = 0; i < filters.data.size(); ++i) {
if (filters.data[i] != 0.0f) {
printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
}
}
}
} g_cache;
}
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
static void dft(const float * in, int N, float * out) {
const int n_sin_cos_vals = g_cache.sin_vals.size();
const int sin_cos_step = n_sin_cos_vals / N;
// Unified DFT implementation for both forward and inverse transforms
// Template parameters:
// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling
// true = IDFT with exp(+2πi·k·n/N), scales by 1/N
// RealInput: true = input is real-valued (stride 1), avoids imaginary computations
// false = input is complex-valued (interleaved real/imag, stride 2)
template <bool Inverse, bool RealInput>
static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) {
const int n_sin_cos_vals = cache.sin_vals.size();
const int sin_cos_step = n_sin_cos_vals / N;
constexpr float sign = Inverse ? 1.0f : -1.0f;
const float scale = Inverse ? (1.0f / N) : 1.0f;
for (int k = 0; k < N; k++) {
float re = 0;
float im = 0;
for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N
re += in[n] * g_cache.cos_vals[idx]; // cos(t)
im -= in[n] * g_cache.sin_vals[idx]; // sin(t)
int idx = (k * n * sin_cos_step) % n_sin_cos_vals;
float cos_val = cache.cos_vals[idx];
float sin_val = cache.sin_vals[idx];
if constexpr (RealInput) {
// Real input: in_im = 0, simplifies to:
// re += in_re * cos_val
// im += sign * in_re * sin_val
float in_re = in[n];
re += in_re * cos_val;
im += sign * in_re * sin_val;
} else {
float in_re = in[n * 2 + 0];
float in_im = in[n * 2 + 1];
// (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i
re += in_re * cos_val - sign * in_im * sin_val;
im += sign * in_re * sin_val + in_im * cos_val;
}
}
out[k*2 + 0] = re;
out[k*2 + 1] = im;
out[k * 2 + 0] = re * scale;
out[k * 2 + 1] = im * scale;
}
}
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
static void fft(float * in, int N, float * out) {
const int n_sin_cos_vals = g_cache.sin_vals.size();
// Cooley-Tukey FFT/IFFT unified implementation
// Template parameters:
// Inverse: false = FFT with exp(-2πi·k/N), no scaling
// true = IFFT with exp(+2πi·k/N), scales by 0.5 at each level
// RealInput: true = input is real-valued (stride 1)
// false = input is complex-valued (interleaved real/imag, stride 2)
template <bool Inverse, bool RealInput>
static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
const int n_sin_cos_vals = cache.sin_vals.size();
if (N == 1) {
out[0] = in[0];
out[1] = 0;
if constexpr (RealInput) {
out[1] = 0.0f;
} else {
out[1] = in[1];
}
return;
}
const int half_N = N / 2;
if (N - half_N*2 == 1) {
dft(in, N, out);
if (N - half_N * 2 == 1) {
// Odd N: fall back to DFT
dft_impl<Inverse, RealInput>(cache, in, N, out);
return;
}
float* even = in + N;
for (int i = 0; i < half_N; ++i) {
even[i]= in[2*i];
}
float* even_fft = out + 2 * N;
fft(even, half_N, even_fft);
// Split into even and odd
if constexpr (RealInput) {
// Real input: stride is 1, copy only real values
float * even = in + N;
for (int i = 0; i < half_N; ++i) {
even[i] = in[2 * i];
}
float * even_fft = out + 2 * N;
fft_impl<Inverse, true>(cache, even, half_N, even_fft);
float* odd = even;
for (int i = 0; i < half_N; ++i) {
odd[i] = in[2*i + 1];
float * odd = even;
for (int i = 0; i < half_N; ++i) {
odd[i] = in[2 * i + 1];
}
float * odd_fft = even_fft + N;
fft_impl<Inverse, true>(cache, odd, half_N, odd_fft);
} else {
// Complex input: stride is 2, copy complex pairs
float * even = in + N * 2;
for (int i = 0; i < half_N; ++i) {
even[i * 2 + 0] = in[2 * i * 2 + 0];
even[i * 2 + 1] = in[2 * i * 2 + 1];
}
float * even_fft = out + 2 * N;
fft_impl<Inverse, false>(cache, even, half_N, even_fft);
float * odd = even;
for (int i = 0; i < half_N; ++i) {
odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0];
odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1];
}
float * odd_fft = even_fft + N;
fft_impl<Inverse, false>(cache, odd, half_N, odd_fft);
}
float* odd_fft = even_fft + N;
fft(odd, half_N, odd_fft);
float * even_fft = out + 2 * N;
float * odd_fft = even_fft + N;
const int sin_cos_step = n_sin_cos_vals / N;
constexpr float sign = Inverse ? 1.0f : -1.0f;
constexpr float scale = Inverse ? 0.5f : 1.0f;
for (int k = 0; k < half_N; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = g_cache.cos_vals[idx]; // cos(t)
float im = -g_cache.sin_vals[idx]; // sin(t)
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = cache.cos_vals[idx];
float im = sign * cache.sin_vals[idx];
float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1];
float re_odd = odd_fft[2 * k + 0];
float im_odd = odd_fft[2 * k + 1];
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd);
out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd);
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd);
out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd);
}
}
// Forward FFT for real input (used by mel spectrogram)
static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
fft_impl<false, true>(cache, in, N, out);
}
// Inverse FFT for complex input
static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
fft_impl<true, false>(cache, in, N, out);
}
struct filter_params {
int32_t n_mel;
int32_t n_fft_bins;
@ -222,20 +265,27 @@ struct filter_params {
bool norm_per_feature = false;
};
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads,
const filter_params & params, mtmd_audio_mel & out) {
static void log_mel_spectrogram_worker_thread(int ith,
const float * hann,
const std::vector<float> & samples,
int n_samples,
int frame_size,
int frame_step,
int n_threads,
const filter_params & params,
const mtmd_audio_cache & cache,
mtmd_audio_mel & out) {
std::vector<float> fft_in(frame_size * 2, 0.0);
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
int n_fft_bins = params.n_fft_bins;
int i = ith;
const auto & filters = g_cache.filters;
const auto & filters = cache.filters;
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2));
GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size());
GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size());
// calculate FFT only when fft_in are not all zero
for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) {
const int offset = i * frame_step;
@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
}
// FFT
fft(fft_in.data(), frame_size, fft_out.data());
fft(cache, fft_in.data(), frame_size, fft_out.data());
// Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
@ -298,6 +348,7 @@ static bool log_mel_spectrogram(
const int n_samples_in,
const int n_threads,
const filter_params & params,
const mtmd_audio_cache & cache,
mtmd_audio_mel & out) {
//const int64_t t_start_us = ggml_time_us();
@ -305,9 +356,9 @@ static bool log_mel_spectrogram(
int n_samples = n_samples_in;
// Hann window
const float * hann = g_cache.hann_window.data();
const int frame_size = (params.n_fft_bins - 1) * 2;
const int frame_step = params.hop_length;
const float * hann = cache.hann_window.data();
const int frame_size = (params.n_fft_bins - 1) * 2;
const int frame_step = params.hop_length;
// Padding
std::vector<float> samples_padded;
@ -335,9 +386,9 @@ static bool log_mel_spectrogram(
// preemphasis
if (params.preemph) {
const int pad_amount = frame_size / 2;
const int pad_amount = frame_size / 2;
const float preemph = 0.97f;
float prev = samples_padded[pad_amount];
float prev = samples_padded[pad_amount];
for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) {
float cur = samples_padded[i];
samples_padded[i] = cur - preemph * prev;
@ -372,14 +423,14 @@ static bool log_mel_spectrogram(
{
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
n_samples, frame_size, frame_step, n_threads,
std::cref(params), std::ref(out));
workers[iw] =
std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples,
frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out));
}
// main thread
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out);
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params,
cache, out);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
@ -404,7 +455,7 @@ static bool log_mel_spectrogram(
for (int j = 0; j < effective_n_len; ++j) {
auto &value = out.data[i * out.n_len + j];
value = (value - mean) / mstd;
value = (value - mean) / mstd;
}
// pad the rest with zeros
@ -450,18 +501,14 @@ static bool log_mel_spectrogram(
//
void mtmd_audio_preprocessor_whisper::initialize() {
g_cache.fill_sin_cos_table(hparams.audio_n_fft);
g_cache.fill_hann_window(hparams.audio_window_len, true);
g_cache.fill_mel_filterbank_matrix(
hparams.n_mel_bins,
hparams.audio_n_fft,
hparams.audio_sample_rate);
cache.fill_sin_cos_table(hparams.audio_n_fft);
cache.fill_hann_window(hparams.audio_window_len, true);
cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
}
bool mtmd_audio_preprocessor_whisper::preprocess(
const float * samples,
size_t n_samples,
std::vector<mtmd_audio_mel> & output) {
bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples,
size_t n_samples,
std::vector<mtmd_audio_mel> & output) {
if (n_samples == 0) {
// empty audio
return false;
@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
// if input is too short, pad with zeros
// this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram
// TODO: maybe handle this better
size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
if (n_samples < min_samples) {
smpl.resize(min_samples, 0.0f);
std::memcpy(smpl.data(), samples, n_samples * sizeof(float));
@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
params.hop_length = hparams.audio_hop_len;
params.sample_rate = hparams.audio_sample_rate;
params.center_padding = false;
params.preemph = 0.0f; // disabled
params.preemph = 0.0f; // disabled
params.use_natural_log = false;
params.norm_per_feature = false;
// make sure the global cache is initialized
GGML_ASSERT(!g_cache.sin_vals.empty());
GGML_ASSERT(!g_cache.cos_vals.empty());
GGML_ASSERT(!g_cache.filters.data.empty());
// make sure the cache is initialized
GGML_ASSERT(!cache.sin_vals.empty());
GGML_ASSERT(!cache.cos_vals.empty());
GGML_ASSERT(!cache.filters.data.empty());
mtmd_audio_mel out_full;
bool ok = log_mel_spectrogram(
samples,
n_samples,
4, // n_threads
params,
out_full);
bool ok = log_mel_spectrogram(samples, n_samples,
4, // n_threads
params, cache, out_full);
if (!ok) {
return false;
}
@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len);
}
const size_t frames_per_chunk = 3000;
GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
if ((size_t)n_len < frames_per_chunk) {
break; // last uncomplete chunk will always be a padded chunk, safe to ignore
GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk);
for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) {
int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off);
if ((size_t) n_len < frames_per_chunk) {
break; // last uncomplete chunk will always be a padded chunk, safe to ignore
}
mtmd_audio_mel out_chunk;
out_chunk.n_len = n_len;
out_chunk.n_mel = out_full.n_mel;
out_chunk.n_len_org = out_full.n_mel; // unused
out_chunk.n_len_org = out_full.n_mel; // unused
out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
for (int i = 0; i < out_full.n_mel; i++) {
auto src = out_full.data.begin() + i*out_full.n_len + off;
auto src = out_full.data.begin() + i * out_full.n_len + off;
out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
}
@ -541,18 +585,14 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
//
void mtmd_audio_preprocessor_conformer::initialize() {
g_cache.fill_sin_cos_table(hparams.audio_n_fft);
g_cache.fill_hann_window(hparams.audio_window_len, true);
g_cache.fill_mel_filterbank_matrix(
hparams.n_mel_bins,
hparams.audio_n_fft,
hparams.audio_sample_rate);
cache.fill_sin_cos_table(hparams.audio_n_fft);
cache.fill_hann_window(hparams.audio_window_len, true);
cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
}
bool mtmd_audio_preprocessor_conformer::preprocess(
const float * samples,
size_t n_samples,
std::vector<mtmd_audio_mel> & output) {
bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples,
size_t n_samples,
std::vector<mtmd_audio_mel> & output) {
// empty audio
if (n_samples == 0) {
return false;
@ -569,18 +609,15 @@ bool mtmd_audio_preprocessor_conformer::preprocess(
params.use_natural_log = true;
params.norm_per_feature = true;
// make sure the global cache is initialized
GGML_ASSERT(!g_cache.sin_vals.empty());
GGML_ASSERT(!g_cache.cos_vals.empty());
GGML_ASSERT(!g_cache.filters.data.empty());
// make sure the cache is initialized
GGML_ASSERT(!cache.sin_vals.empty());
GGML_ASSERT(!cache.cos_vals.empty());
GGML_ASSERT(!cache.filters.data.empty());
mtmd_audio_mel out_full;
bool ok = log_mel_spectrogram(
samples,
n_samples,
4, // n_threads
params,
out_full);
bool ok = log_mel_spectrogram(samples, n_samples,
4, // n_threads
params, cache, out_full);
if (!ok) {
return false;
}
@ -588,3 +625,106 @@ bool mtmd_audio_preprocessor_conformer::preprocess(
output.push_back(std::move(out_full));
return true;
}
//
// mtmd_audio_streaming_istft implementation
//
mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) :
n_fft(n_fft),
hop_length(hop_length),
n_fft_bins(n_fft / 2 + 1),
overlap_buffer(n_fft, 0.0f),
window_sum_buffer(n_fft, 0.0f),
padding_to_remove((n_fft - hop_length) / 2),
ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT
ifft_out(n_fft * 2 * 4, 0.0f) {
cache.fill_sin_cos_table(n_fft);
cache.fill_hann_window(n_fft, true);
}
void mtmd_audio_streaming_istft::reset() {
std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f);
std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f);
padding_to_remove = (n_fft - hop_length) / 2;
}
std::vector<float> mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) {
std::vector<float> output(hop_length);
// copy frequencies
for (int j = 0; j < n_fft_bins; j++) {
ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0];
ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1];
}
// mirror negative frequencies
for (int j = 1; j < n_fft_bins - 1; j++) {
int mirror_idx = n_fft - j;
ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0];
ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate
}
ifft(cache, ifft_in.data(), n_fft, ifft_out.data());
// update window sum and overlap buffer
for (int j = 0; j < n_fft; j++) {
window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j];
overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j];
}
// extract hop_length samples with normalization
for (int i = 0; i < hop_length; i++) {
if (window_sum_buffer[i] > 1e-8f) {
output[i] = overlap_buffer[i] / window_sum_buffer[i];
} else {
output[i] = overlap_buffer[i];
}
}
// shift buffers left by hop_length
std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin());
std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f);
std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin());
std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f);
// Remove padding if needed
int to_remove = std::min(padding_to_remove, (int) output.size());
padding_to_remove -= to_remove;
output.erase(output.begin(), output.begin() + to_remove);
return output;
}
std::vector<float> mtmd_audio_streaming_istft::flush() {
std::vector<float> output;
// Extract remaining samples from overlap buffer
// Continue until we've extracted all meaningful samples
int remaining = n_fft - hop_length;
while (remaining > 0) {
int chunk_size = std::min(remaining, hop_length);
for (int i = 0; i < chunk_size; i++) {
float sample;
if (window_sum_buffer[i] > 1e-8f) {
sample = overlap_buffer[i] / window_sum_buffer[i];
} else {
sample = overlap_buffer[i];
}
output.push_back(sample);
}
// Shift buffers
std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin());
std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f);
std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin());
std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f);
remaining -= chunk_size;
}
return output;
}

View File

@ -17,6 +17,38 @@ struct mtmd_audio_mel {
std::vector<float> data;
};
struct mtmd_audio_mel_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
// cache for audio processing, each processor instance owns its own cache
struct mtmd_audio_cache {
std::vector<float> sin_vals;
std::vector<float> cos_vals;
std::vector<float> hann_window;
mtmd_audio_mel_filters filters;
void fill_sin_cos_table(int n);
void fill_hann_window(int length, bool periodic);
// Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
// n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
void fill_mel_filterbank_matrix(int n_mel,
int n_fft,
int sample_rate, // e.g. 16000
float fmin = 0.0f, // e.g. 0.0
float fmax = -1.0f, // e.g. sr/2; pass -1 for auto
bool slaney_area_norm = true,
float scale = 1.0f // optional extra scaling
);
};
struct mtmd_audio_preprocessor {
const clip_hparams & hparams;
@ -31,10 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor {
mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
void initialize() override;
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
private:
mtmd_audio_cache cache;
};
struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor {
mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
void initialize() override;
bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
private:
mtmd_audio_cache cache;
};
//
// streaming ISTFT - converts spectrogram frames back to audio one frame at a time
//
struct mtmd_audio_streaming_istft {
mtmd_audio_streaming_istft(int n_fft, int hop_length);
// reset streaming state
void reset();
// process a single STFT frame (streaming)
// frame_spectrum: [n_fft_bins x 2] interleaved real/imag
// returns: up to hop_length samples
std::vector<float> process_frame(const float * frame_spectrum);
// flush remaining samples at end of stream
std::vector<float> flush();
private:
int n_fft;
int hop_length;
int n_fft_bins;
// Own cache for output processing
mtmd_audio_cache cache;
// Streaming state
std::vector<float> overlap_buffer;
std::vector<float> window_sum_buffer;
int padding_to_remove;
// Working buffers for IFFT
std::vector<float> ifft_in;
std::vector<float> ifft_out;
};

View File

@ -814,6 +814,15 @@ json server_task_result_cmpl_final::to_json_anthropic() {
msg.content = content;
}
// thinking block comes first (Anthropic extended thinking format)
if (!msg.reasoning_content.empty()) {
content_blocks.push_back({
{"type", "thinking"},
{"thinking", msg.reasoning_content},
{"signature", ""} // empty signature for local models (no cryptographic verification)
});
}
if (!msg.content.empty()) {
content_blocks.push_back({
{"type", "text"},
@ -862,20 +871,57 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
}
bool has_text = !oaicompat_msg.content.empty();
bool has_thinking = !oaicompat_msg.reasoning_content.empty();
bool has_text = !oaicompat_msg.content.empty();
size_t num_tool_calls = oaicompat_msg.tool_calls.size();
bool text_block_started = false;
// content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+)
size_t thinking_block_index = 0;
size_t text_block_index = has_thinking ? 1 : 0;
bool thinking_block_started = false;
bool text_block_started = false;
std::unordered_set<size_t> tool_calls_started;
for (const auto & diff : oaicompat_msg_diffs) {
// handle thinking/reasoning content
if (!diff.reasoning_content_delta.empty()) {
if (!thinking_block_started) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", thinking_block_index},
{"content_block", {
{"type", "thinking"},
{"thinking", ""}
}}
}}
});
thinking_block_started = true;
}
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", thinking_block_index},
{"delta", {
{"type", "thinking_delta"},
{"thinking", diff.reasoning_content_delta}
}}
}}
});
}
// handle regular text content
if (!diff.content_delta.empty()) {
if (!text_block_started) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", 0},
{"index", text_block_index},
{"content_block", {
{"type", "text"},
{"text", ""}
@ -889,7 +935,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", 0},
{"index", text_block_index},
{"delta", {
{"type", "text_delta"},
{"text", diff.content_delta}
@ -898,8 +944,9 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
});
}
// handle tool calls
if (diff.tool_call_index != std::string::npos) {
size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;
size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + diff.tool_call_index;
if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];
@ -935,18 +982,42 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
}
}
// close content blocks in order
if (has_thinking) {
// Anthropic API requires a signature_delta before closing thinking blocks
// We use an empty signature since we can't generate a cryptographic signature for local models
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", thinking_block_index},
{"delta", {
{"type", "signature_delta"},
{"signature", ""}
}}
}}
});
events.push_back({
{"event", "content_block_stop"},
{"data", {
{"type", "content_block_stop"},
{"index", thinking_block_index}
}}
});
}
if (has_text) {
events.push_back({
{"event", "content_block_stop"},
{"data", {
{"type", "content_block_stop"},
{"index", 0}
{"index", text_block_index}
}}
});
}
for (size_t i = 0; i < num_tool_calls; i++) {
size_t content_block_index = (has_text ? 1 : 0) + i;
size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + i;
events.push_back({
{"event", "content_block_stop"},
{"data", {
@ -1154,11 +1225,10 @@ json server_task_result_rerank::to_json() {
json server_task_result_cmpl_partial::to_json_anthropic() {
json events = json::array();
bool first = (n_decoded == 1);
bool text_block_started = false;
// use member variables to track block state across streaming calls
// (anthropic_thinking_block_started, anthropic_text_block_started)
if (first) {
text_block_started = false;
events.push_back({
{"event", "message_start"},
{"data", {
@ -1180,28 +1250,69 @@ json server_task_result_cmpl_partial::to_json_anthropic() {
});
}
// content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+)
size_t thinking_block_index = 0;
// use anthropic_has_reasoning (set in update()) to know if ANY reasoning was generated
size_t text_block_index = anthropic_has_reasoning ? 1 : 0;
// use local copies of streaming state (copied from task_result_state in update())
// these reflect the state BEFORE this chunk was processed
bool thinking_started = anthropic_thinking_block_started;
bool text_started = anthropic_text_block_started;
for (const auto & diff : oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
if (!text_block_started) {
// handle thinking/reasoning content
if (!diff.reasoning_content_delta.empty()) {
if (!thinking_started) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", 0},
{"index", thinking_block_index},
{"content_block", {
{"type", "text"},
{"text", ""}
{"type", "thinking"},
{"thinking", ""}
}}
}}
});
text_block_started = true;
thinking_started = true;
}
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", 0},
{"index", thinking_block_index},
{"delta", {
{"type", "thinking_delta"},
{"thinking", diff.reasoning_content_delta}
}}
}}
});
}
// handle regular text content
if (!diff.content_delta.empty()) {
if (!text_started) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", text_block_index},
{"content_block", {
{"type", "text"},
{"text", ""}
}}
}}
});
text_started = true;
}
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", text_block_index},
{"delta", {
{"type", "text_delta"},
{"text", diff.content_delta}
@ -1210,8 +1321,10 @@ json server_task_result_cmpl_partial::to_json_anthropic() {
});
}
// handle tool calls
if (diff.tool_call_index != std::string::npos) {
size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;
// use anthropic_has_reasoning for thinking block count (persists across calls)
size_t content_block_index = (anthropic_has_reasoning ? 1 : 0) + (text_started ? 1 : 0) + diff.tool_call_index;
if (!diff.tool_call_delta.name.empty()) {
events.push_back({

View File

@ -96,6 +96,10 @@ struct task_result_state {
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;
// for Anthropic API streaming: track content block state across chunks
bool anthropic_thinking_block_started = false;
bool anthropic_text_block_started = false;
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
@ -337,6 +341,12 @@ struct server_task_result_cmpl_partial : server_task_result {
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
// for Anthropic API: track if any reasoning content has been generated
bool anthropic_has_reasoning = false;
// Streaming state copied from task_result_state for this chunk
bool anthropic_thinking_block_started = false;
bool anthropic_text_block_started = false;
virtual bool is_stop() override {
return false; // in stream mode, partial responses are not considered stop
}
@ -346,6 +356,22 @@ struct server_task_result_cmpl_partial : server_task_result {
virtual void update(task_result_state & state) override {
is_updated = true;
state.update_chat_msg(content, true, oaicompat_msg_diffs);
// track if the accumulated message has any reasoning content
anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty();
// Copy current state for use in to_json_anthropic() (reflects state BEFORE this chunk)
anthropic_thinking_block_started = state.anthropic_thinking_block_started;
anthropic_text_block_started = state.anthropic_text_block_started;
// Pre-compute state updates based on diffs (for next chunk)
for (const auto & diff : oaicompat_msg_diffs) {
if (!diff.reasoning_content_delta.empty() && !state.anthropic_thinking_block_started) {
state.anthropic_thinking_block_started = true;
}
if (!diff.content_delta.empty() && !state.anthropic_text_block_started) {
state.anthropic_text_block_started = true;
}
}
}
json to_json_non_oaicompat();

View File

@ -805,3 +805,92 @@ def test_anthropic_vs_openai_different_response_format():
assert "input_tokens" in anthropic_res.body["usage"]
assert "completion_tokens" in openai_res.body["usage"]
assert "output_tokens" in anthropic_res.body["usage"]
# Extended thinking tests with reasoning models
@pytest.mark.slow
@pytest.mark.parametrize("stream", [False, True])
def test_anthropic_thinking_with_reasoning_model(stream):
"""Test that thinking content blocks are properly returned for reasoning models"""
global server
server = ServerProcess()
server.model_hf_repo = "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF"
server.model_hf_file = "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf"
server.reasoning_format = "deepseek"
server.jinja = True
server.n_ctx = 8192
server.n_predict = 1024
server.server_port = 8084
server.start(timeout_seconds=600) # large model needs time to download
if stream:
res = server.make_stream_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 1024,
"thinking": {
"type": "enabled",
"budget_tokens": 500
},
"messages": [
{"role": "user", "content": "What is 2+2?"}
],
"stream": True
})
events = list(res)
# should have thinking content block events
thinking_starts = [e for e in events if
e.get("type") == "content_block_start" and
e.get("content_block", {}).get("type") == "thinking"]
assert len(thinking_starts) > 0, "Should have thinking content_block_start event"
assert thinking_starts[0]["index"] == 0, "Thinking block should be at index 0"
# should have thinking_delta events
thinking_deltas = [e for e in events if
e.get("type") == "content_block_delta" and
e.get("delta", {}).get("type") == "thinking_delta"]
assert len(thinking_deltas) > 0, "Should have thinking_delta events"
# should have signature_delta event before thinking block closes (Anthropic API requirement)
signature_deltas = [e for e in events if
e.get("type") == "content_block_delta" and
e.get("delta", {}).get("type") == "signature_delta"]
assert len(signature_deltas) > 0, "Should have signature_delta event for thinking block"
# should have text block after thinking
text_starts = [e for e in events if
e.get("type") == "content_block_start" and
e.get("content_block", {}).get("type") == "text"]
assert len(text_starts) > 0, "Should have text content_block_start event"
assert text_starts[0]["index"] == 1, "Text block should be at index 1 (after thinking)"
else:
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 1024,
"thinking": {
"type": "enabled",
"budget_tokens": 500
},
"messages": [
{"role": "user", "content": "What is 2+2?"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
content = res.body["content"]
assert len(content) >= 2, "Should have at least thinking and text blocks"
# first block should be thinking
thinking_blocks = [b for b in content if b.get("type") == "thinking"]
assert len(thinking_blocks) > 0, "Should have thinking content block"
assert "thinking" in thinking_blocks[0], "Thinking block should have 'thinking' field"
assert len(thinking_blocks[0]["thinking"]) > 0, "Thinking content should not be empty"
assert "signature" in thinking_blocks[0], "Thinking block should have 'signature' field (Anthropic API requirement)"
# should also have text block
text_blocks = [b for b in content if b.get("type") == "text"]
assert len(text_blocks) > 0, "Should have text content block"