Merge branch 'ggml-org:master' into Kimi-Linear

commit 40f6118192

common/arg.cpp | 107
@@ -854,6 +854,54 @@ bool common_arg_utils::is_autoy(const std::string & value) {
    return value == "auto" || value == "-1";
}

// Simple CSV parser that handles quoted fields and escaped quotes
// example:
// input:  value1,"value, with, commas","value with ""escaped"" quotes",value4
// output: [value1] [value, with, commas] [value with "escaped" quotes] [value4]
static std::vector<std::string> parse_csv_row(const std::string & input) {
    std::vector<std::string> fields;
    std::string field;
    bool in_quotes = false;

    for (size_t i = 0; i < input.length(); ++i) {
        char ch = input[i];

        if (ch == '"') {
            if (!in_quotes) {
                // start of quoted field (only valid if at beginning of field)
                if (!field.empty()) {
                    // quote appeared in middle of unquoted field, treat as literal
                    field += '"';
                } else {
                    in_quotes = true; // start
                }
            } else {
                if (i + 1 < input.length() && input[i + 1] == '"') {
                    // escaped quote: ""
                    field += '"';
                    ++i; // skip the next quote
                } else {
                    in_quotes = false; // end
                }
            }
        } else if (ch == ',') {
            if (in_quotes) {
                field += ',';
            } else {
                fields.push_back(std::move(field));
                field.clear();
            }
        } else {
            field += ch;
        }
    }

    // Add the last field
    fields.push_back(std::move(field));

    return fields;
}

common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
    // per-example default params
    // we define here to make sure it's included in llama-gen-docs
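For illustration only (not part of the commit): a minimal standalone sketch of how the new parser splits the documented example input, assuming the parse_csv_row definition from the hunk above is compiled into the same translation unit.

// Hypothetical check of parse_csv_row; the function itself is assumed to be
// available from the hunk above.
#include <cassert>
#include <string>
#include <vector>

static std::vector<std::string> parse_csv_row(const std::string & input); // defined above

int main() {
    const std::string input =
        "value1,\"value, with, commas\",\"value with \"\"escaped\"\" quotes\",value4";
    const std::vector<std::string> fields = parse_csv_row(input);

    assert(fields.size() == 4);
    assert(fields[0] == "value1");
    assert(fields[1] == "value, with, commas");
    assert(fields[2] == "value with \"escaped\" quotes");
    assert(fields[3] == "value4");
    return 0;
}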
@@ -1250,7 +1298,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--in-file"}, "FNAME",
        "an input file (use comma-separated values to specify multiple files)",
        [](common_params & params, const std::string & value) {
            for (const auto & item : string_split<std::string>(value, ',')) {
            for (const auto & item : parse_csv_row(value)) {
                std::ifstream file(item);
                if (!file) {
                    throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
@@ -2002,7 +2050,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--image", "--audio"}, "FILE",
        "path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
        [](common_params & params, const std::string & value) {
            for (const auto & item : string_split<std::string>(value, ',')) {
            for (const auto & item : parse_csv_row(value)) {
                params.image.emplace_back(item);
            }
        }
@@ -2259,37 +2307,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ));
    add_opt(common_arg(
        {"--override-kv"}, "KEY=TYPE:VALUE,...",
        "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n"
        "advanced option to override model metadata by key. to specify multiple overrides, use comma-separated values.\n"
        "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false",
        [](common_params & params, const std::string & value) {
            std::vector<std::string> kv_overrides;

            std::string current;
            bool escaping = false;

            for (const char c : value) {
                if (escaping) {
                    current.push_back(c);
                    escaping = false;
                } else if (c == '\\') {
                    escaping = true;
                } else if (c == ',') {
                    kv_overrides.push_back(current);
                    current.clear();
                } else {
                    current.push_back(c);
                }
            }

            if (escaping) {
                current.push_back('\\');
            }

            kv_overrides.push_back(current);

            for (const auto & kv_override : kv_overrides) {
                if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) {
                    throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str()));
            for (const auto & item : parse_csv_row(value)) {
                if (!string_parse_kv_override(item.c_str(), params.kv_overrides)) {
                    throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", item.c_str()));
                }
            }
        }
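Note the practical difference: the removed loop honored backslash-escaped commas, while the CSV parser expects a comma-containing value to be wrapped in double quotes. A hypothetical migration sketch (not part of the commit; the key/value used below is only an example, and parse_csv_row is assumed from the hunk earlier in this diff):

// Old style "KEY=str:a\,b" is no longer unescaped; quote the whole override instead.
#include <string>
#include <vector>

static std::vector<std::string> parse_csv_row(const std::string & input); // defined above

std::vector<std::string> split_override_kv_example() {
    return parse_csv_row("\"tokenizer.chat_template=str:a,b\",tokenizer.ggml.add_bos_token=bool:false");
    // -> { "tokenizer.chat_template=str:a,b", "tokenizer.ggml.add_bos_token=bool:false" }
}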
@@ -2306,7 +2329,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--lora"}, "FNAME",
        "path to LoRA adapter (use comma-separated values to load multiple adapters)",
        [](common_params & params, const std::string & value) {
            for (const auto & item : string_split<std::string>(value, ',')) {
            for (const auto & item : parse_csv_row(value)) {
                params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
            }
        }
@@ -2317,7 +2340,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n"
        "note: use comma-separated values",
        [](common_params & params, const std::string & value) {
            for (const auto & item : string_split<std::string>(value, ',')) {
            for (const auto & item : parse_csv_row(value)) {
                auto parts = string_split<std::string>(item, ':');
                if (parts.size() != 2) {
                    throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
@@ -2331,7 +2354,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--control-vector"}, "FNAME",
        "add a control vector\nnote: use comma-separated values to add multiple control vectors",
        [](common_params & params, const std::string & value) {
            for (const auto & item : string_split<std::string>(value, ',')) {
            for (const auto & item : parse_csv_row(value)) {
                params.control_vectors.push_back({ 1.0f, item, });
            }
        }
@@ -2341,7 +2364,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        "add a control vector with user defined scaling SCALE\n"
        "note: use comma-separated values (format: FNAME:SCALE,...)",
        [](common_params & params, const std::string & value) {
            for (const auto & item : string_split<std::string>(value, ',')) {
            for (const auto & item : parse_csv_row(value)) {
                auto parts = string_split<std::string>(item, ':');
                if (parts.size() != 2) {
                    throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE");
@@ -2439,7 +2462,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--context-file"}, "FNAME",
        "file to load context from (use comma-separated values to specify multiple files)",
        [](common_params & params, const std::string & value) {
            for (const auto & item : string_split<std::string>(value, ',')) {
            for (const auto & item : parse_csv_row(value)) {
                std::ifstream file(item, std::ios::binary);
                if (!file) {
                    throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
@@ -2675,9 +2698,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
    add_opt(common_arg(
        {"--api-key"}, "KEY",
        "API key to use for authentication (default: none)",
        "API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)",
        [](common_params & params, const std::string & value) {
            params.api_keys.push_back(value);
            for (const auto & key : parse_csv_row(value)) {
                if (!key.empty()) {
                    params.api_keys.push_back(key);
                }
            }
        }
    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
    add_opt(common_arg(
@@ -2691,7 +2718,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            std::string key;
            while (std::getline(key_file, key)) {
                if (!key.empty()) {
                    params.api_keys.push_back(key);
                    params.api_keys.push_back(key);
                }
            }
            key_file.close();
@@ -2713,7 +2740,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE"));
    add_opt(common_arg(
        {"--chat-template-kwargs"}, "STRING",
        string_format("sets additional params for the json template parser"),
        "sets additional params for the json template parser, must be a valid json object string, e.g. '{\"key1\":\"value1\",\"key2\":\"value2\"}'",
        [](common_params & params, const std::string & value) {
            auto parsed = json::parse(value);
            for (const auto & item : parsed.items()) {
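For illustration only: a sketch of the input format the updated --chat-template-kwargs help text describes, assuming the json alias used by common/arg.cpp refers to nlohmann's JSON type as elsewhere in the codebase. Keys and values below are arbitrary examples.

// Hypothetical sketch of parsing a '{"key":"value"}' argument into per-key strings.
#include <map>
#include <string>
#include <nlohmann/json.hpp>

int main() {
    const auto parsed = nlohmann::json::parse(R"({"key1":"value1","key2":"value2"})");

    std::map<std::string, std::string> kwargs;
    for (const auto & item : parsed.items()) {
        // keep each value in its serialized JSON form, mirroring how additional
        // template parameters would be forwarded to the template parser
        kwargs[item.key()] = item.value().dump();
    }
    return kwargs.size() == 2 ? 0 : 1;
}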
@@ -1963,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
    acl_tensor_ptr acl_weight_tensor;

    // Only check env once.
    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
    static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
    if (weight_to_nz && is_matmul_weight(weight)) {
        acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
    } else {

@@ -103,7 +103,7 @@ const ggml_cann_device_info & ggml_cann_info();
void ggml_cann_set_device(int32_t device);
int32_t ggml_cann_get_device();

std::optional<std::string> get_env(const std::string & name);
std::optional<std::string> get_env_as_lowercase(const std::string & name);
bool parse_bool(const std::string & value);
int parse_integer(const std::string & value);

@@ -105,10 +105,10 @@ int32_t ggml_cann_get_device() {
}

/**
 * @brief Get the value of the specified environment variable (name).
 * @brief Get the value of the specified environment variable (name) as lowercase.
 * if not empty, return a std::string object
 */
std::optional<std::string> get_env(const std::string & name) {
std::optional<std::string> get_env_as_lowercase(const std::string & name) {
    const char * val = std::getenv(name.c_str());
    if (!val) {
        return std::nullopt;
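The hunks above show only the renamed declaration and the start of the definition, so the lowercasing step itself is not visible here. A plausible sketch of what such a helper does, stated as an assumption rather than the commit's exact body:

// Hypothetical re-implementation for illustration; the real function lives in the
// CANN backend and may differ in detail.
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <optional>
#include <string>

std::optional<std::string> get_env_as_lowercase_sketch(const std::string & name) {
    const char * val = std::getenv(name.c_str());
    if (!val) {
        return std::nullopt;
    }
    std::string out(val);
    std::transform(out.begin(), out.end(), out.begin(),
                   [](unsigned char c) { return (char) std::tolower(c); });
    return out; // e.g. "ON" -> "on", so parse_bool("on") behaves as expected
}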
@@ -259,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_buf_prio(int device) : device(device) {
        disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
        disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
    }

    /**

@@ -452,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_buf(int device) : device(device) {
        disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
        disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
    }

    /**

@@ -764,7 +764,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
 * @return A unique pointer to the created CANN pool.
 */
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
    std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
    std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or("");

    if (mem_pool_type == "prio") {
        GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);

@@ -1217,7 +1217,7 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
    // Why aclrtSynchronizeDevice?

    // Only check env once.
    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
    static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
    if (!need_transform(tensor->type)) {
        ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
    if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {

@@ -1442,7 +1442,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
    int64_t ne0 = tensor->ne[0];

    // Only check env once.
    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
    static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));

    // last line must bigger than 32, because every single op deal at
    // least 32 bytes.

@@ -2136,7 +2136,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#endif // USE_ACL_GRAPH
    // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
    // With the use of CANN graphs, the execution will be performed by the graph launch.
    static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or(""));
    static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));

    if (!use_cann_graph || cann_graph_capture_required) {
        for (int i = 0; i < cgraph->n_nodes; i++) {

@@ -2201,7 +2201,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
#ifdef USE_ACL_GRAPH
    bool use_cann_graph = true;

    static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
    static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
    if (!prefill_use_graph) {
        // Do not use acl_graph for prefill.
        for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -1036,7 +1036,7 @@ struct ggml_tensor_extra_gpu {
#define USE_CUDA_GRAPH
#endif

struct ggml_graph_node_properties {
struct ggml_cuda_graph_node_properties {
    void * node_address;
    ggml_op node_op;
    int64_t ne[GGML_MAX_DIMS];

@@ -1061,11 +1061,25 @@ struct ggml_cuda_graph {
    std::vector<cudaGraphNode_t> nodes;
    bool disable_due_to_gpu_arch = false;
    bool disable_due_to_too_many_updates = false;
    bool disable_due_to_failed_graph_capture = false;
    int number_consecutive_updates = 0;
    bool cuda_graphs_enabled = false;
    std::vector<ggml_graph_node_properties> ggml_graph_properties;
    std::vector<ggml_graph_node_properties> extraneous_srcs_properties;
    std::vector<ggml_cuda_graph_node_properties> props;

    void record_update(bool use_graph, bool update_required) {
        if (use_graph && update_required) {
            number_consecutive_updates++;
        } else {
            number_consecutive_updates = 0;
        }
        if (number_consecutive_updates >= 4) {
            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
            disable_due_to_too_many_updates = true;
        }
    }

    bool is_enabled() const {
        static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
        return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
    }
#endif
};
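For illustration only: a standalone model of the update-throttling logic added above, reproducing just the counter/flag behaviour of record_update() and is_enabled() (the CUDA-specific members are omitted, and this struct name is purely hypothetical).

// Hypothetical standalone model of the consecutive-update throttle.
#include <cassert>

struct graph_update_throttle {
    bool disable_due_to_too_many_updates = false;
    int  number_consecutive_updates      = 0;

    void record_update(bool use_graph, bool update_required) {
        if (use_graph && update_required) {
            number_consecutive_updates++;
        } else {
            number_consecutive_updates = 0;
        }
        if (number_consecutive_updates >= 4) {
            disable_due_to_too_many_updates = true;
        }
    }

    bool is_enabled() const { return !disable_due_to_too_many_updates; }
};

int main() {
    graph_update_throttle t;
    for (int i = 0; i < 3; i++) {
        t.record_update(/*use_graph=*/true, /*update_required=*/true);
    }
    assert(t.is_enabled());      // three consecutive updates are still fine
    t.record_update(true, true); // the fourth consecutive update trips the limit
    assert(!t.is_enabled());     // graphs stay disabled from the next evaluation on
    return 0;
}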
|
||||
|
|
|
|||
|
|
@ -2853,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
|||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
|
||||
bool use_cuda_graph) {
|
||||
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
|
||||
bool use_cuda_graph = true;
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
|
||||
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
|
||||
|
|
@ -2915,41 +2915,41 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
|
|||
return use_cuda_graph;
|
||||
}
|
||||
|
||||
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
graph_node_properties->node_address = node->data;
|
||||
graph_node_properties->node_op = node->op;
|
||||
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
|
||||
props->node_address = node->data;
|
||||
props->node_op = node->op;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
graph_node_properties->ne[i] = node->ne[i];
|
||||
graph_node_properties->nb[i] = node->nb[i];
|
||||
props->ne[i] = node->ne[i];
|
||||
props->nb[i] = node->nb[i];
|
||||
}
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
||||
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
||||
}
|
||||
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
}
|
||||
|
||||
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
if (node->data != graph_node_properties->node_address &&
|
||||
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
|
||||
if (node->data != props->node_address &&
|
||||
node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->op != graph_node_properties->node_op) {
|
||||
if (node->op != props->node_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (node->ne[i] != graph_node_properties->ne[i]) {
|
||||
if (node->ne[i] != props->ne[i]) {
|
||||
return false;
|
||||
}
|
||||
if (node->nb[i] != graph_node_properties->nb[i]) {
|
||||
if (node->nb[i] != props->nb[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node->src[i] &&
|
||||
node->src[i]->data != graph_node_properties->src_address[i] &&
|
||||
node->src[i]->data != props->src_address[i] &&
|
||||
node->op != GGML_OP_VIEW
|
||||
) {
|
||||
return false;
|
||||
|
|
@ -2957,56 +2957,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
|||
}
|
||||
|
||||
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
|
||||
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||
memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
|
||||
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
|
||||
|
||||
bool cuda_graph_update_required = false;
|
||||
bool res = false;
|
||||
|
||||
if (cuda_ctx->cuda_graph->instance == nullptr) {
|
||||
cuda_graph_update_required = true;
|
||||
res = true;
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
cuda_graph_update_required = true;
|
||||
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
res = true;
|
||||
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
}
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
// and store properties to allow this comparison for the next token
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool has_matching_properties = true;
|
||||
|
||||
if (!cuda_graph_update_required) {
|
||||
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
||||
bool props_match = true;
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
|
||||
}
|
||||
if (!has_matching_properties) {
|
||||
cuda_graph_update_required = true;
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
||||
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
bool has_matching_properties = true;
|
||||
if (!cuda_graph_update_required) {
|
||||
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]);
|
||||
bool props_match= true;
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
|
||||
}
|
||||
if (!has_matching_properties) {
|
||||
cuda_graph_update_required = true;
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
set_ggml_graph_node_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]);
|
||||
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
}
|
||||
|
||||
return cuda_graph_update_required;
|
||||
return res;
|
||||
}
|
||||
|
||||
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
|
|
@ -3237,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
return false;
|
||||
}
|
||||
|
||||
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
||||
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
// flag used to determine whether it is an integrated_gpu
|
||||
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
||||
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
||||
|
||||
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
|
||||
bool is_concurrent_event_active = false;
|
||||
|
|
@ -3710,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
}
|
||||
if (cuda_graph_update_required) { // Update graph executable
|
||||
update_cuda_graph_executable(cuda_ctx);
|
||||
ggml_cuda_graph_update_executable(cuda_ctx);
|
||||
}
|
||||
// Launch graph
|
||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||
|
|
@ -3720,43 +3720,25 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
}
|
||||
}
|
||||
|
||||
static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
|
||||
// Objects required for CUDA Graph
|
||||
if (cuda_ctx->cuda_graph == nullptr) {
|
||||
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
|
||||
}
|
||||
|
||||
bool use_cuda_graph = true;
|
||||
|
||||
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
|
||||
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
|
||||
// or previous graph capture failure.
|
||||
// Also disable for multi-gpu for now. TO DO investigate
|
||||
if (disable_cuda_graphs_due_to_env
|
||||
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|
||||
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|
||||
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
|
||||
use_cuda_graph = false;
|
||||
}
|
||||
|
||||
cuda_ctx->cuda_graph->cuda_graphs_enabled = use_cuda_graph;
|
||||
return cuda_ctx->cuda_graph->is_enabled();
|
||||
#else
|
||||
bool use_cuda_graph = false;
|
||||
return false;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
return use_cuda_graph;
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
|
|
@ -3767,30 +3749,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|||
bool use_cuda_graph = false;
|
||||
bool cuda_graph_update_required = false;
|
||||
|
||||
// graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
|
||||
// we call it here instead.
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
|
||||
|
||||
if (use_cuda_graph) {
|
||||
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
if (cuda_ctx->cuda_graph->is_enabled()) {
|
||||
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
|
||||
|
||||
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
|
||||
|
||||
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
|
||||
if (use_cuda_graph && cuda_graph_update_required) {
|
||||
cuda_ctx->cuda_graph->number_consecutive_updates++;
|
||||
} else {
|
||||
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
|
||||
}
|
||||
|
||||
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
|
||||
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
|
||||
cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
|
||||
#endif
|
||||
}
|
||||
cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
|
|
@ -3804,9 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|||
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
||||
}
|
||||
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
|
||||
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
|
@ -3839,7 +3803,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
|
|||
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
const bool use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
|
||||
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
|
||||
|
||||
static bool enable_graph_optimization = [] {
|
||||
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
|
||||
|
|
|
|||
|
|
@ -34,13 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
// CUDA_GRAPHS_DISABLED
|
||||
((ncols > 65536) &&
|
||||
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
|
||||
ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
|
||||
ctx.cuda_graph->is_enabled())) ||
|
||||
// CUDA_GRAPHS ENABLED
|
||||
((ncols > 32768) &&
|
||||
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
|
||||
ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
|
||||
ctx.cuda_graph->is_enabled()))) {
|
||||
#else
|
||||
(ncols > 65536)) {
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
|
|
|||
|
|
@ -333,6 +333,28 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
|
|||
}
|
||||
|
||||
if (amd_wmma_available(cc)) {
|
||||
// RDNA 4 is consistently worse on rocblas
|
||||
// https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
|
||||
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
|
||||
// High expert counts almost always better on MMQ
|
||||
// due to a large amount of graph splits
|
||||
// https://github.com/ggml-org/llama.cpp/pull/18202
|
||||
if (n_experts >= 64) {
|
||||
return true;
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
// These quants are really bad on MMQ
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
// These quants are usually worse but not always
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
return ne11 <= 128;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
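For illustration only: a standalone restatement of the RDNA3 branch added to ggml_cuda_should_use_mmq above, so the quant-type / batch-size decision table is easy to eyeball. The enum below is a stand-in for ggml_type, not the real type.

// Hypothetical standalone model of the new RDNA3 decision.
enum class quant { q2_k, q6_k, iq2_xs, iq2_s, other };

static bool rdna3_should_use_mmq(quant type, long ne11, long n_experts) {
    // high expert counts almost always favour MMQ due to the large number of graph splits
    if (n_experts >= 64) {
        return true;
    }
    switch (type) {
        // Q2_K/Q6_K are really bad on MMQ; IQ2_XS/IQ2_S are usually (not always) worse
        case quant::q2_k:
        case quant::q6_k:
        case quant::iq2_xs:
        case quant::iq2_s:
            return ne11 <= 128; // fall back to rocBLAS for larger batches
        default:
            return true;
    }
}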
|
||||
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1)
|
|||
#endif // __clang__
|
||||
|
||||
// assumes as many threads as d_state
|
||||
template <int splitH, int d_state>
|
||||
template <int c_factor, int d_state>
|
||||
__global__ void __launch_bounds__(d_state, 1)
|
||||
ssm_scan_f32_group(
|
||||
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
|
||||
|
|
@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1)
|
|||
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
|
||||
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
|
||||
|
||||
const int head_idx = (blockIdx.x * splitH) / d_head;
|
||||
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int warp = threadIdx.x / WARP_SIZE;
|
||||
const int lane = threadIdx.x % WARP_SIZE;
|
||||
const int warp_idx = blockIdx.x * c_factor + warp;
|
||||
|
||||
const int head_idx = warp_idx / d_head;
|
||||
const int head_off = (warp_idx % d_head) * sizeof(float);
|
||||
const int seq_idx = blockIdx.y;
|
||||
|
||||
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
|
||||
|
||||
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
|
||||
const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
|
||||
const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
|
||||
const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
|
||||
const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
|
||||
float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
|
||||
float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
// TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
|
||||
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
|
||||
const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
|
||||
const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
|
||||
const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
|
||||
const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
|
||||
float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
|
||||
float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
|
||||
// strides across n_seq_tokens
|
||||
const int stride_x = src1_nb2 / sizeof(float);
|
||||
|
|
@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1)
|
|||
const int stride_C = src5_nb2 / sizeof(float);
|
||||
const int stride_y = n_head * d_head;
|
||||
|
||||
float state[splitH];
|
||||
// for the parallel accumulation
|
||||
__shared__ float stateC[splitH * d_state];
|
||||
float state[c_factor];
|
||||
float state_sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < splitH; j++) {
|
||||
state[j] = s0_block[j * d_state + threadIdx.x];
|
||||
for (int j = 0; j < c_factor; j++) {
|
||||
state[j] = s0_warp[WARP_SIZE * j + lane];
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < n_tok; i++) {
|
||||
// TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
|
||||
// TODO: only calculate B and C once per head group
|
||||
// NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
|
||||
float dt_soft_plus = dt_block[i * stride_dt];
|
||||
if (dt_soft_plus <= 20.0f) {
|
||||
dt_soft_plus = log1pf(expf(dt_soft_plus));
|
||||
}
|
||||
const float dA = expf(dt_soft_plus * A_block[0]);
|
||||
const float B = B_block[i * stride_B + threadIdx.x];
|
||||
const float C = C_block[i * stride_C + threadIdx.x];
|
||||
// NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
|
||||
// Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
|
||||
const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);
|
||||
|
||||
// across d_head
|
||||
state_sum = 0.0f;
|
||||
const float dA = expf(dt_soft_plus * A_warp[0]);
|
||||
const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < splitH; j++) {
|
||||
const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
|
||||
|
||||
state[j] = (state[j] * dA) + (B * x_dt);
|
||||
|
||||
stateC[j * d_state + threadIdx.x] = state[j] * C;
|
||||
for (int j = 0; j < c_factor; j++) {
|
||||
const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
|
||||
const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
|
||||
state[j] = (state[j] * dA) + (B_val * x_dt);
|
||||
state_sum += state[j] * C_val;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// parallel accumulation for output
|
||||
state_sum = warp_reduce_sum(state_sum);
|
||||
|
||||
// parallel accumulation for stateC
|
||||
// TODO: simplify
|
||||
{
|
||||
static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
|
||||
static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
|
||||
|
||||
// reduce until w matches the warp size
|
||||
// TODO: does this work even when the physical warp size is 64?
|
||||
#pragma unroll
|
||||
for (int w = d_state; w > WARP_SIZE; w >>= 1) {
|
||||
// (assuming there are d_state threads)
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
|
||||
// TODO: check for bank conflicts
|
||||
const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
|
||||
stateC[k] += stateC[k + (w >> 1)];
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
static_assert(splitH >= d_state / WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
|
||||
float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
|
||||
y = warp_reduce_sum(y);
|
||||
|
||||
// store the above accumulations
|
||||
if (threadIdx.x % WARP_SIZE == 0) {
|
||||
const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
|
||||
y_block[i * stride_y + k] = y;
|
||||
}
|
||||
}
|
||||
if (lane == 0) {
|
||||
y_warp[i * stride_y] = state_sum;
|
||||
}
|
||||
}
|
||||
|
||||
// write back the state
|
||||
#pragma unroll
|
||||
for (int j = 0; j < splitH; j++) {
|
||||
s_block[j * d_state + threadIdx.x] = state[j];
|
||||
for (int j = 0; j < c_factor; j++) {
|
||||
s_warp[WARP_SIZE * j + lane] = state[j];
|
||||
}
|
||||
}
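For reference (not part of the commit): a scalar restatement of the per-element recurrence that ssm_scan_f32_group computes, which can serve as a correctness baseline for the warp-level version above. Names follow the kernel; all stride and layout handling is deliberately omitted, and the function itself is hypothetical.

// Hypothetical scalar model of the scan for one head element across n_tok tokens.
#include <cmath>
#include <vector>

void ssm_scan_ref_one_element(std::vector<float> & state,                    // [d_state]
                              const std::vector<float> & x,                  // [n_tok] input for this element
                              const std::vector<float> & dt,                 // [n_tok] per-head dt
                              float A,                                       // per-head A
                              const std::vector<std::vector<float>> & B,     // [n_tok][d_state]
                              const std::vector<std::vector<float>> & C,     // [n_tok][d_state]
                              std::vector<float> & y) {                      // [n_tok] output for this element
    const size_t n_tok   = x.size();
    const size_t d_state = state.size();
    for (size_t i = 0; i < n_tok; i++) {
        // softplus is only applied below the same threshold the kernel uses
        const float dt_sp = dt[i] <= 20.0f ? std::log1p(std::exp(dt[i])) : dt[i];
        const float dA    = std::exp(dt_sp * A);
        const float x_dt  = x[i] * dt_sp;
        float sum = 0.0f;
        for (size_t s = 0; s < d_state; s++) {
            state[s] = state[s] * dA + B[i][s] * x_dt; // state update
            sum     += state[s] * C[i][s];             // output accumulation
        }
        y[i] = sum;
    }
}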
|
||||
|
||||
|
|
@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
|||
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
|
||||
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
|
||||
cudaStream_t stream) {
|
||||
const int threads = 128;
|
||||
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
|
||||
if (src3_nb1 == sizeof(float)) {
|
||||
// Mamba-2
|
||||
if (d_state == 128) {
|
||||
GGML_ASSERT(d_state % threads == 0);
|
||||
// NOTE: can be any power of two between 4 and 64
|
||||
const int splitH = 16;
|
||||
GGML_ASSERT(head_dim % splitH == 0);
|
||||
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
|
||||
ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
|
||||
constexpr int threads = 128;
|
||||
constexpr int num_warps = threads/WARP_SIZE;
|
||||
|
||||
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
|
||||
ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
||||
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
||||
} else if (d_state == 256) { // Falcon-H1
|
||||
const int threads = 256;
|
||||
// NOTE: can be any power of two between 8 and 64
|
||||
const int splitH = 16;
|
||||
GGML_ASSERT(head_dim % splitH == 0);
|
||||
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
|
||||
ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
|
||||
constexpr int threads = 256;
|
||||
constexpr int num_warps = threads/WARP_SIZE;
|
||||
|
||||
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
|
||||
ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
||||
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
||||
|
|
@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
|||
}
|
||||
} else {
|
||||
// Mamba-1
|
||||
constexpr int threads = 128;
|
||||
GGML_ASSERT(n_head % threads == 0);
|
||||
GGML_ASSERT(head_dim == 1);
|
||||
GGML_ASSERT(n_group == 1);
|
||||
|
|
|
|||
|
|
@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * src2 = op->src[2];
|
||||
const struct ggml_tensor * src3 = op->src[3];
|
||||
const struct ggml_tensor * src4 = op->src[4];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
// Check for F16 support only as requested
|
||||
if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src3 && src3->type != GGML_TYPE_F16) { // mask
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src4 && src4->type != GGML_TYPE_F32) { // sinks
|
||||
return false;
|
||||
}
|
||||
|
||||
// For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
|
||||
// but the op implementation writes to F16 or F32.
|
||||
// Let's assume dst can be F32 or F16.
|
||||
if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return opt_experimental;
|
||||
}
|
||||
|
||||
static bool hex_supported_src0_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
|
@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
|||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-cont tensors
|
||||
if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
|||
return false; // typically the lm-head which would be too large for VTCM
|
||||
}
|
||||
|
||||
// if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false;
|
||||
if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
|
|||
}
|
||||
break;
|
||||
|
||||
case GGML_TYPE_F16:
|
||||
if (!opt_experimental) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-cont tensors
|
||||
if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0]; // values
|
||||
const struct ggml_tensor * src1 = op->src[1]; // indices
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0]; // values
|
||||
const struct ggml_tensor * src1 = op->src[1]; // indices
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const int32_t * op_params = &op->op_params[0];
|
||||
|
||||
|
|
@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t
|
|||
d->offset = (uint8_t *) t->data - buf->base;
|
||||
d->size = ggml_nbytes(t);
|
||||
|
||||
if (!d->size) {
|
||||
// Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
|
||||
d->size = 64;
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
|
||||
// Flush CPU
|
||||
|
|
@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
|
|||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_GET_ROWS;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
template <bool _is_src0_constant>
|
||||
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
switch (t->op) {
|
||||
|
|
@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer *
|
|||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_SET_ROWS;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
|
||||
|
|
@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
|||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_SCALE:
|
||||
req->op = HTP_OP_SCALE;
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
|
||||
req->op = HTP_OP_UNARY_SILU;
|
||||
|
|
@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs
|
|||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
req->op = HTP_OP_FLASH_ATTN_EXT;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
|
||||
return sess->name.c_str();
|
||||
|
|
@ -2417,6 +2522,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
|||
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
|
|
@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
|||
ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_SET_ROWS:
|
||||
ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
|
||||
}
|
||||
|
|
@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
break;
|
||||
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
|
|
@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
supp = ggml_hexagon_supported_rope(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SET_ROWS:
|
||||
supp = ggml_hexagon_supported_set_rows(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_GET_ROWS:
|
||||
supp = ggml_hexagon_supported_get_rows(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED
|
|||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
flash-attn-ops.c
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
|
|
|||
|
|
@ -0,0 +1,566 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x_hf = Q6_V_vand_QV(bmask, x_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
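For reference (not part of the commit): a scalar equivalent of hvx_dot_f32_f16_aa, assuming the fp16 operand has already been widened to float. The real kernel converts y to fp16 and accumulates in qf32 on HVX, so small rounding differences versus this float baseline are expected; the function name is hypothetical.

// Hypothetical scalar baseline: *r = dot(y_f32, x_f16_widened) * s.
#include <cstddef>

static void dot_f32_f16_ref(float * r, const float * y, const float * x_widened,
                            size_t n, float s) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; i++) {
        sum += y[i] * x_widened[i];
    }
    *r = sum * s;
}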
|
||||
|
||||
// Dot product of two F16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (float)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
    const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
    HVX_Vector * restrict ptr_y = (HVX_Vector *) y;

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    HVX_Vector S = hvx_vec_splat_fp16(s);

    uint32_t i = 0;
    #pragma unroll(4)
    for (i = 0; i < nvec; ++i) {
        // Multiply x * s -> pair of F32 vectors
        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
        ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
        ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
    }

    if (nloe) {
        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);

        HVX_Vector xs = Q6_V_lo_W(xs_p);
        i = 2 * i; // index for ptr_y

        if (nloe >= 32) {
            ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
            nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
        }

        if (nloe) {
            HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
            hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
        }
    }
}

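// A scalar sketch of the MAD helper above (hypothetical reference, not used by
// the backend): it accumulates y[i] += (float) x[i] * s for i in [0, n).
static inline void hvx_mad_ref_scalar(float * restrict y, const __fp16 * restrict x, int n, float s) {
    for (int i = 0; i < n; i++) {
        y[i] += (float) x[i] * s; // same update the vectorized qf32 path performs
    }
}
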
#define FLASH_ATTN_BLOCK_SIZE 128

static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
    const struct htp_tensor * q     = &octx->src0;
    const struct htp_tensor * k     = &octx->src1;
    const struct htp_tensor * v     = &octx->src2;
    const struct htp_tensor * mask  = (octx->src3.data) ? &octx->src3 : NULL;
    const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
    struct htp_tensor * dst = &octx->dst;

    const uint32_t neq0 = q->ne[0];
    const uint32_t neq1 = q->ne[1];
    const uint32_t neq2 = q->ne[2];
    const uint32_t neq3 = q->ne[3];

    const uint32_t nek0 = k->ne[0];
    const uint32_t nek1 = k->ne[1];
    const uint32_t nek2 = k->ne[2];
    const uint32_t nek3 = k->ne[3];

    const uint32_t nev0 = v->ne[0];
    const uint32_t nev1 = v->ne[1];
    const uint32_t nev2 = v->ne[2];
    const uint32_t nev3 = v->ne[3];

    const uint32_t nbq1 = q->nb[1];
    const uint32_t nbq2 = q->nb[2];
    const uint32_t nbq3 = q->nb[3];

    const uint32_t nbk1 = k->nb[1];
    const uint32_t nbk2 = k->nb[2];
    const uint32_t nbk3 = k->nb[3];

    const uint32_t nbv1 = v->nb[1];
    const uint32_t nbv2 = v->nb[2];
    const uint32_t nbv3 = v->nb[3];

    const uint32_t ne1 = dst->ne[1];
    const uint32_t ne2 = dst->ne[2];
    const uint32_t ne3 = dst->ne[3];

    const uint32_t nb1 = dst->nb[1];
    const uint32_t nb2 = dst->nb[2];
    const uint32_t nb3 = dst->nb[3];

    float scale         = 1.0f;
    float max_bias      = 0.0f;
    float logit_softcap = 0.0f;

    memcpy(&scale,         (float *) octx->op_params + 0, sizeof(float));
    memcpy(&max_bias,      (float *) octx->op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));

    if (logit_softcap != 0) {
        scale /= logit_softcap;
    }

    // total rows in q
    const uint32_t nr = neq1*neq2*neq3;

    const uint32_t dr  = (nr + nth - 1) / nth;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = MIN(ir0 + dr, nr);

    if (ir0 >= ir1) return;

    dma_queue * dma = octx->ctx->dma[ith];

    const uint32_t DK = nek0;
    const uint32_t DV = nev0;

    const size_t size_q_row        = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
    const size_t size_q_row_padded = htp_round_up(size_q_row, 128);

    const size_t size_k_row = DK * sizeof(__fp16);
    const size_t size_v_row = DV * sizeof(__fp16);
    const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask

    const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
    const size_t size_v_row_padded = htp_round_up(size_v_row, 128);

    const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
    const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
    const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);

    // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
    uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
    uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
    uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
    uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
    uint8_t * spad_a = octx->dst_spad.data  + octx->dst_spad.size_per_thread  * ith;

    const uint32_t n_head      = neq2;
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t ir = ir0; ir < ir1; ++ir) {
        const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
        const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
        const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);

        const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
        const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);

        const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
        const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);

        // Fetch Q row
        const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
        dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);

        const uint32_t h = iq2; // head index
        const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;

        float S = 0.0f;      // sum
        float M = -INFINITY; // maximum KQ value

        // Clear accumulator
        float * VKQ32 = (float *) spad_a;
        memset(VKQ32, 0, DV * sizeof(float));

        const __fp16 * mp_base = NULL;
        if (mask) {
            const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
            const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
            mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
        }

        const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;

        // Prefetch first two blocks
        for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);

            // K
            const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
            uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
            dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);

            // V
            const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
            uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
            dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);

            // Mask
            if (mask) {
                const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
                uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
                // Mask is 1D contiguous for this row
                dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
            }
        }

        const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;

        for (uint32_t ib = 0; ib < n_blocks; ++ib) {
            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);

            // Wait for DMA
            uint8_t * k_base = dma_queue_pop(dma).dst; // K
            uint8_t * v_base = dma_queue_pop(dma).dst; // V
            __fp16 *  m_base = mask ? dma_queue_pop(dma).dst : NULL; // M

            // Inner loop processing the block from VTCM
            uint32_t ic = 0;

            // Process in blocks of 32 (VLEN_FP32)
            for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
                // 1. Compute scores
                float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
                for (int j = 0; j < VLEN_FP32; ++j) {
                    const uint32_t cur_ic = ic + j;
                    const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
                    if (q->type == HTP_TYPE_F32) {
                        hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
                    } else {
                        hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
                    }
                }

                HVX_Vector scores = *(HVX_Vector *) scores_arr;

                // 2. Softcap
                if (logit_softcap != 0.0f) {
                    scores = hvx_vec_tanh_fp32(scores);
                    scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
                    scores = Q6_Vsf_equals_Vqf32(scores);
                }

                // 3. Mask
                if (mask) {
                    const __fp16 * mp = m_base + ic;
                    HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;

                    HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
                    HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);

                    HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));

                    HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
                    HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
                    scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
                    scores = Q6_Vsf_equals_Vqf32(scores);
                }

                // 4. Online Softmax Update
                HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
                float m_block = hvx_vec_get_fp32(v_max);

                float M_old = M;
                float M_new = (m_block > M) ? m_block : M;
                M = M_new;

                float ms = expf(M_old - M_new);

                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
                S = S * ms;

                HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
                HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
                HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));

                HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
                float p_sum = hvx_vec_get_fp32(p_sum_vec);
                S += p_sum;

                // 5. Accumulate V
                float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
                *(HVX_Vector*)p_arr = P;

                for (int j = 0; j < VLEN_FP32; ++j) {
                    const uint32_t cur_ic = ic + j;
                    const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
                    hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
                }
            }

            // Leftover
            for (; ic < current_block_size; ++ic) {
                float s_val;
                const uint8_t * k_ptr = k_base + ic * size_k_row_padded;

                if (q->type == HTP_TYPE_F32) {
                    hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
                } else {
                    hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
                }

                if (logit_softcap != 0.0f) {
                    s_val = logit_softcap * tanhf(s_val);
                }

                if (mask) {
                    const float m_val = m_base[ic];
                    s_val += slope * m_val;
                }

                const float Mold = M;
                float ms = 1.0f;
                float vs = 1.0f;

                if (s_val > M) {
                    M = s_val;
                    ms = expf(Mold - M);
                    hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
                } else {
                    vs = expf(s_val - M);
                }

                const uint8_t * v_ptr = v_base + ic * size_v_row_padded;

                hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);

                S = S * ms + vs;
            }

            // Issue DMA for next+1 block (if exists)
            if (ib + 2 < n_blocks) {
                const uint32_t next_ib       = ib + 2;
                const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
                const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);

                // K
                const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
                dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);

                // V
                const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
                dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);

                // Mask
                if (mask) {
                    const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
                    dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
                }
            }
        }

        // sinks
        if (sinks) {
            const float s = ((float *)((char *) sinks->data))[h];

            float ms = 1.0f;
            float vs = 1.0f;

            if (s > M) {
                ms = expf(M - s);
                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
            } else {
                vs = expf(s - M);
            }

            S = S * ms + vs;
        }

        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
        hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);

        // Store result
        // dst indices
        const int i1 = iq1;
        const int i2 = iq2;
        const int i3 = iq3;

        // dst is permuted
        uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;

        if (dst->type == HTP_TYPE_F32) {
            hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
        } else if (dst->type == HTP_TYPE_F16) {
            hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
        }
    }
}

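// A compact scalar model of the streaming ("online") softmax update used in
// the block loop above (names are illustrative only). The real code keeps M
// (running max), S (running denominator) and VKQ32 (running weighted V sum),
// and rescales them whenever a new score raises the maximum.
static inline void online_softmax_step_ref(float * M, float * S, float * acc, uint32_t dv,
                                           float score, const float * v_row) {
    const float M_new = score > *M ? score : *M;
    const float ms = expf(*M - M_new);    // rescale factor for the old accumulator
    const float vs = expf(score - M_new); // weight of the new row

    for (uint32_t i = 0; i < dv; i++) {
        acc[i] = acc[i] * ms + vs * v_row[i];
    }
    *S = *S * ms + vs;
    *M = M_new;
    // After all rows are processed the caller divides acc[] by S, exactly as
    // flash_attn_ext_f16_thread() does with hvx_scale_f32_aa(..., 1/S).
}
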
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
    struct htp_ops_context * octx = data;
    flash_attn_ext_f16_thread(octx, i, n);
}

int op_flash_attn_ext(struct htp_ops_context * octx) {
    const struct htp_tensor * q    = &octx->src0;
    const struct htp_tensor * k    = &octx->src1;
    const struct htp_tensor * v    = &octx->src2;
    const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
    struct htp_tensor * dst = &octx->dst;

    // Check support
    if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
        k->type != HTP_TYPE_F16 ||
        v->type != HTP_TYPE_F16) {
        return HTP_STATUS_NO_SUPPORT;
    }

    octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
    octx->src0_div1  = init_fastdiv_values(q->ne[1]);

    octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
    octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
    octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
    octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);

    if (mask) {
        octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
        octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
    }

    size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
    size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
    size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);

    size_t size_q_block = size_q_row_padded * 1; // single row for now
    size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
    size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
    size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);

    size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32

    octx->src0_spad.size_per_thread = size_q_block * 1;
    octx->src1_spad.size_per_thread = size_k_block * 2;
    octx->src2_spad.size_per_thread = size_v_block * 2;
    octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
    octx->dst_spad.size_per_thread  = size_vkq_acc;

    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
    octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
    octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
    octx->dst_spad.size  = octx->dst_spad.size_per_thread  * octx->n_threads;

    size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;

    if (octx->ctx->vtcm_size < total_spad) {
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
    octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
    octx->dst_spad.data  = octx->src3_spad.data + octx->src3_spad.size;

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
    }

    return HTP_STATUS_OK;
}

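// Worked example of the scratchpad sizing above (assumed shapes, for
// orientation only): with head dims DK = DV = 128, an fp32 Q row, a mask, and
// 4 worker threads:
//   size_q_row_padded = round_up(128*4, 128) = 512 B   -> 512 B per thread
//   size_k_block      = 256 * 128 = 32 KiB, double-buffered -> 64 KiB per thread
//   size_v_block      = 256 * 128 = 32 KiB, double-buffered -> 64 KiB per thread
//   size_m_block      = 256 B, double-buffered             -> 512 B per thread
//   size_vkq_acc      = round_up(128*4, 128) = 512 B
// i.e. roughly 129.5 KiB per thread and about 518 KiB for 4 threads, which
// must fit inside the device's VTCM or the op falls back with
// HTP_STATUS_VTCM_TOO_SMALL.
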
@@ -0,0 +1,112 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"

#define get_rows_preamble \
    const uint32_t ne00 = octx->src0.ne[0]; \
    const uint32_t ne01 = octx->src0.ne[1]; \
    const uint32_t ne02 = octx->src0.ne[2]; \
    const uint32_t ne03 = octx->src0.ne[3]; \
    \
    const uint32_t ne10 = octx->src1.ne[0]; \
    const uint32_t ne11 = octx->src1.ne[1]; \
    const uint32_t ne12 = octx->src1.ne[2]; \
    \
    const uint32_t nb01 = octx->src0.nb[1]; \
    const uint32_t nb02 = octx->src0.nb[2]; \
    const uint32_t nb03 = octx->src0.nb[3]; \
    \
    const uint32_t nb10 = octx->src1.nb[0]; \
    const uint32_t nb11 = octx->src1.nb[1]; \
    const uint32_t nb12 = octx->src1.nb[2]; \
    \
    const uint32_t nb1 = octx->dst.nb[1]; \
    const uint32_t nb2 = octx->dst.nb[2]; \
    const uint32_t nb3 = octx->dst.nb[3]; \
    \
    const uint32_t nr = ne10 * ne11 * ne12;

static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
    get_rows_preamble;

    // parallelize by src1 elements (which correspond to dst rows)
    const uint32_t dr  = octx->src1_nrows_per_thread;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;

    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);

    for (uint32_t i = ir0; i < ir1; ++i) {
        const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
        const uint32_t rem = i - i12 * ne11 * ne10;
        const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
        const uint32_t i10 = rem - i11 * ne10;

        const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;

        uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;

        if (i01 >= ne01) {
            // invalid index, skip for now to avoid crash
            continue;
        }

        const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
        const uintptr_t dst_ptr  = octx->dst.data  + i10*nb1  + i11*nb2  + i12*nb3;
        hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
    }

    return HTP_STATUS_OK;
}

static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
    get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}

int op_get_rows(struct htp_ops_context * octx) {
    get_rows_preamble;

    if (octx->src0.type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->dst.type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    octx->get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
    octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);

    const uint32_t n_jobs = MIN(nr, octx->n_threads);
    octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;

    worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
    return HTP_STATUS_OK;
}

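// Reference semantics of GET_ROWS, for orientation (illustrative only, not
// part of the backend): every index in src1 selects one row of src0, and that
// row is copied to the corresponding row of dst.
static inline void get_rows_ref(float * dst, const float * src0, const int32_t * idx,
                                uint32_t n_rows, uint32_t row_len) {
    for (uint32_t r = 0; r < n_rows; r++) {
        const float * src_row = src0 + (uint32_t) idx[r] * row_len;
        for (uint32_t c = 0; c < row_len; c++) {
            dst[r * row_len + c] = src_row[c];
        }
    }
}
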
@@ -11,11 +11,6 @@

#define HTP_MAX_NTHREADS 10

// FIXME: move these into matmul-ops
#define HTP_SPAD_SRC0_NROWS 16
#define HTP_SPAD_SRC1_NROWS 16
#define HTP_SPAD_DST_NROWS 2

// Main context for htp DSP backend
struct htp_context {
    dspqueue_t queue;

@@ -36,6 +36,8 @@ enum htp_data_type {
    HTP_TYPE_F16 = 1,
    HTP_TYPE_Q4_0 = 2,
    HTP_TYPE_Q8_0 = 8,
    HTP_TYPE_I32 = 26,
    HTP_TYPE_I64 = 27,
    HTP_TYPE_MXFP4 = 39,
    HTP_TYPE_COUNT
};

@@ -57,6 +59,10 @@ enum htp_op {
    HTP_OP_SOFTMAX = 11,
    HTP_OP_ADD_ID = 12,
    HTP_OP_ROPE = 13,
    HTP_OP_FLASH_ATTN_EXT = 14,
    HTP_OP_SET_ROWS = 15,
    HTP_OP_SCALE = 16,
    HTP_OP_GET_ROWS = 17,
    INVALID
};

@@ -137,6 +143,8 @@ struct htp_general_req {
    struct htp_tensor src0; // Input0 tensor
    struct htp_tensor src1; // Input1 tensor
    struct htp_tensor src2; // Input2 tensor
    struct htp_tensor src3; // Input3 tensor
    struct htp_tensor src4; // Input4 tensor
    struct htp_tensor dst;  // Output tensor

    // should be multiple of 64 bytes (cacheline)

@@ -152,6 +160,6 @@ struct htp_general_rsp {
};

#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
#define HTP_MAX_PACKET_BUFFERS 4
#define HTP_MAX_PACKET_BUFFERS 8

#endif /* HTP_MSG_H */

@@ -13,6 +13,7 @@

struct htp_spad {
    uint8_t * data;
    size_t stride;
    size_t size;
    size_t size_per_thread;
};

@@ -26,11 +27,14 @@ struct htp_ops_context {
    struct htp_tensor src0;
    struct htp_tensor src1;
    struct htp_tensor src2;
    struct htp_tensor src3;
    struct htp_tensor src4;
    struct htp_tensor dst;

    struct htp_spad src0_spad;
    struct htp_spad src1_spad;
    struct htp_spad src2_spad;
    struct htp_spad src3_spad;
    struct htp_spad dst_spad;

    worker_pool_context_t * wpool; // worker pool

@@ -49,6 +53,27 @@ struct htp_ops_context {
    struct fastdiv_values src1_div3;  // fastdiv values for ne3
    struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1

    struct fastdiv_values src3_div1;  // fastdiv values for ne1
    struct fastdiv_values src3_div2;  // fastdiv values for ne2
    struct fastdiv_values src3_div3;  // fastdiv values for ne3
    struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1

    struct fastdiv_values broadcast_rk2;
    struct fastdiv_values broadcast_rk3;
    struct fastdiv_values broadcast_rv2;
    struct fastdiv_values broadcast_rv3;

    struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
    struct fastdiv_values mm_div_ne1;      // fastdiv values for ne1
    struct fastdiv_values mm_div_r2;       // fastdiv values for ne12 / ne02
    struct fastdiv_values mm_div_r3;       // fastdiv values for ne13 / ne03

    struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
    struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11

    struct fastdiv_values get_rows_div_ne10;      // fastdiv values for ne10
    struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11

    uint32_t flags;
};

@@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
int op_rope(struct htp_ops_context * octx);
int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);

#endif /* HTP_OPS_H */
|
|||
|
|
@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
|
|||
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
|
||||
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
|
@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
|
|||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
|
|||
}
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_splat_fp32(float i) {
|
||||
static inline HVX_Vector hvx_vec_splat_fp32(float v) {
|
||||
union {
|
||||
float f;
|
||||
int32_t i;
|
||||
} fp32 = { .f = i };
|
||||
float f;
|
||||
uint32_t i;
|
||||
} fp32 = { .f = v };
|
||||
|
||||
return Q6_V_vsplat_R(fp32.i);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_splat_fp16(float v) {
|
||||
union {
|
||||
__fp16 f;
|
||||
uint16_t i;
|
||||
} fp16 = { .f = v };
|
||||
|
||||
return Q6_Vh_vsplat_R(fp16.i);
|
||||
}
|
||||
|
||||
static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
|
||||
// Rotate as needed.
|
||||
v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
|
||||
|
|
@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest
|
|||
}
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source is unaligned, destination unaligned
|
||||
static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_UVector * restrict vdst = (HVX_UVector *) dst;
|
||||
HVX_UVector * restrict vsrc = (HVX_UVector *) src;
|
||||
|
||||
|
||||
|
||||
uint32_t nvec = n / 32;
|
||||
uint32_t nloe = n % 32;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
vdst[i] = v;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
|
||||
}
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
|
||||
static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
|
||||
HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
|
||||
}
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
|
||||
static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
|
||||
HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
|
||||
}
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
|
||||
static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16
|
||||
HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
|
||||
}
|
||||
}
|
||||
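// All three fp32 -> fp16 copy variants above share the same per-element
// behaviour; only the pointer alignment they assume differs. Scalar model
// (illustrative only, not part of the backend):
static inline void copy_fp16_fp32_ref(__fp16 * dst, const float * src, uint32_t n) {
    for (uint32_t i = 0; i < n; i++) {
        dst[i] = (__fp16) src[i]; // narrowing conversion of each element
    }
}
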
|
||||
// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
|
||||
static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
|
@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
|
|||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||
HVX_VectorAlias u = { .v = v };
|
||||
|
||||
|
|
@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
|
|||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
|
||||
#if __HTP_ARCH__ > 75
|
||||
#if __HVX_ARCH__ > 75
|
||||
return Q6_Vsf_vfneg_Vsf(v);
|
||||
#else
|
||||
// neg by setting the fp32 sign bit
|
||||
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
|
||||
return Q6_V_vxor_VV(v, mask);
|
||||
#endif // __HTP_ARCH__ > 75
|
||||
#endif // __HVX_ARCH__ > 75
|
||||
}
|
||||
|
||||
// ====================================================
|
||||
|
|
@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
|
|||
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) {
|
||||
// tanh(x) = 2 * sigmoid(2x) - 1
|
||||
HVX_Vector two = hvx_vec_splat_fp32(2.0f);
|
||||
HVX_Vector one = hvx_vec_splat_fp32(1.0f);
|
||||
HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
|
||||
|
||||
    static const float kMinExp = -87.f; // below this the guarded sigmoid saturates to 0
    static const float kMaxExp =  87.f; // above this the guarded sigmoid saturates to 1
|
||||
HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
|
||||
HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
|
||||
|
||||
HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
|
||||
|
||||
HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
|
||||
res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
|
||||
return Q6_Vsf_equals_Vqf32(res);
|
||||
}
|
||||
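// Scalar model of the identity used above (illustrative, not part of the
// backend): tanh(x) = 2*sigmoid(2x) - 1, which is how hvx_vec_tanh_fp32()
// reuses the fast sigmoid.
static inline float tanh_via_sigmoid_ref(float x) {
    const float sig2x = 1.0f / (1.0f + expf(-2.0f * x));
    return 2.0f * sig2x - 1.0f;
}
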
|
||||
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||
int step_of_1 = num_elems >> 5;
|
||||
int remaining = num_elems - step_of_1 * VLEN_FP32;
|
||||
|
|
@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr
|
|||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
int nvec = n / VLEN_FP32;
|
||||
int nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector vs = hvx_vec_splat_fp32(scale);
|
||||
|
||||
HVX_Vector * vsrc = (HVX_Vector *) src;
|
||||
HVX_Vector * vdst = (HVX_Vector *) dst;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
int nvec = n / VLEN_FP32;
|
||||
int nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector vs = hvx_vec_splat_fp32(scale);
|
||||
|
||||
HVX_UVector * vsrc = (HVX_UVector *) src;
|
||||
HVX_UVector * vdst = (HVX_UVector *) dst;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
|
||||
hvx_scale_f32_aa(dst, src, n, scale);
|
||||
} else {
|
||||
hvx_scale_f32_uu(dst, src, n, scale);
|
||||
}
|
||||
}
|
||||
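// The aa/uu pair above implements the same element-wise operation; only the
// load/store alignment differs. Scalar model (hypothetical, for reference):
// dst[i] = src[i] * scale for i in [0, n). The hvx_scale_f32() dispatcher just
// picks the aligned variant when both pointers are 128-byte aligned.
static inline void scale_f32_ref(float * dst, const float * src, int n, float scale) {
    for (int i = 0; i < n; i++) {
        dst[i] = src[i] * scale;
    }
}
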
|
||||
static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
int nvec = n / VLEN_FP32;
|
||||
int nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector vs = hvx_vec_splat_fp32(scale);
|
||||
HVX_Vector vo = hvx_vec_splat_fp32(offset);
|
||||
|
||||
HVX_Vector * vsrc = (HVX_Vector *) src;
|
||||
HVX_Vector * vdst = (HVX_Vector *) dst;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
int nvec = n / VLEN_FP32;
|
||||
int nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector vs = hvx_vec_splat_fp32(scale);
|
||||
HVX_Vector vo = hvx_vec_splat_fp32(offset);
|
||||
|
||||
HVX_UVector * vsrc = (HVX_UVector *) src;
|
||||
HVX_UVector * vdst = (HVX_UVector *) dst;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
|
||||
hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
|
||||
} else {
|
||||
hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
|
||||
}
|
||||
}
|
||||
|
||||
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
|
||||
void hvx_mul_f32(const uint8_t * restrict src0,
|
||||
|
|
@ -1090,7 +1338,6 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0,
|
|||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
|
||||
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
|
||||
|
|
|
|||
|
|
@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context * ctx,
|
|||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
    // We wrote to the output buffer, so we also need to flush it
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_get_rows(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_matmul_id_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
|
|
@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context * ctx,
|
|||
uint32_t n_bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
|
||||
int write_idx = (n_bufs == 4) ? 3 : 2;
|
||||
int write_idx = n_bufs - 1;
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[write_idx].fd;
|
||||
|
|
@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context * ctx,
|
|||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
    // We wrote to the output buffer, so we also need to flush it
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_set_rows(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_flash_attn_ext_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
uint32_t n_bufs) {
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx;
|
||||
memset(&octx, 0, sizeof(octx));
|
||||
|
||||
octx.ctx = ctx;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.src2 = req->src2;
|
||||
octx.src3 = req->src3;
|
||||
octx.src4 = req->src4;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.src2.data = (uint32_t) bufs[2].ptr;
|
||||
|
||||
int last_buf = 3;
|
||||
|
||||
if (octx.src3.ne[0]) {
|
||||
octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
|
||||
}
|
||||
|
||||
if (octx.src4.ne[0]) {
|
||||
octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
|
||||
}
|
||||
|
||||
octx.dst.data = (uint32_t) bufs[last_buf].ptr;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_flash_attn_ext(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
|
||||
struct dspqueue_buffer rsp_buf = bufs[last_buf];
|
||||
rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
|
||||
}
|
||||
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
struct htp_context * ctx = (struct htp_context *) context;
|
||||
|
||||
|
|
@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
break;
|
||||
|
||||
case HTP_OP_RMS_NORM:
|
||||
case HTP_OP_SCALE:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
|
|
@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
proc_rope_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_FLASH_ATTN_EXT:
|
||||
if (!(n_bufs >= 4 && n_bufs <= 6)) {
|
||||
FARF(ERROR, "Bad flash-attn-ext-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SET_ROWS:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad set-rows-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_set_rows_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_GET_ROWS:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad get-rows-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_get_rows_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -0,0 +1,168 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define set_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
const uint32_t ne02 = octx->src0.ne[2]; \
|
||||
const uint32_t ne03 = octx->src0.ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src1.ne[0]; \
|
||||
const uint32_t ne11 = octx->src1.ne[1]; \
|
||||
const uint32_t ne12 = octx->src1.ne[2]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src0.nb[1]; \
|
||||
const uint32_t nb02 = octx->src0.nb[2]; \
|
||||
const uint32_t nb03 = octx->src0.nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src1.nb[0]; \
|
||||
const uint32_t nb11 = octx->src1.nb[1]; \
|
||||
const uint32_t nb12 = octx->src1.nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst.nb[1]; \
|
||||
const uint32_t nb2 = octx->dst.nb[2]; \
|
||||
const uint32_t nb3 = octx->dst.nb[3]; \
|
||||
\
|
||||
const uint32_t ne1 = octx->dst.ne[1]; \
|
||||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
if (i1 >= ne1) {
|
||||
// ignore invalid indices
|
||||
continue;
|
||||
}
|
||||
|
||||
const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
// copy row
|
||||
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
if (i1 >= ne1) {
|
||||
// ignore invalid indices
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_set_rows(struct htp_ops_context * octx) {
|
||||
set_rows_preamble;
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
|
||||
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
switch(octx->dst.type) {
|
||||
case HTP_TYPE_F32:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
|
||||
break;
|
||||
case HTP_TYPE_F16:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
|
||||
break;
|
||||
default:
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
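// Reference semantics of SET_ROWS (illustrative only, not part of the
// backend): it is the scatter counterpart of GET_ROWS -- each source row r is
// written to dst at the row index held in idx[r], and out-of-range indices are
// ignored, matching the policy in op_set_rows().
static inline void set_rows_ref(float * dst, const float * src, const int32_t * idx,
                                uint32_t n_rows, uint32_t row_len, uint32_t dst_rows) {
    for (uint32_t r = 0; r < n_rows; r++) {
        const uint32_t i1 = (uint32_t) idx[r];
        if (i1 >= dst_rows) {
            continue; // skip invalid indices
        }
        for (uint32_t c = 0; c < row_len; c++) {
            dst[i1 * row_len + c] = src[r * row_len + c];
        }
    }
}
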
|
|
@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
|
|||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
|
||||
if (mp_f32) {
|
||||
if (softmax_ctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
|
|
@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
|
|||
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
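// For reference, the scalar shape of the softmax row update in the hunk above
// (names are illustrative): scale the logits, subtract the row max,
// exponentiate, then multiply by 1/sum -- which is what the two
// hvx_scale_f32() calls do on either side of hvx_softmax_f32().
static inline void softmax_row_ref(float * dst, const float * src, int n, float scale) {
    float max = -INFINITY;
    for (int i = 0; i < n; i++) {
        const float v = src[i] * scale;
        dst[i] = v;
        if (v > max) max = v;
    }
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        dst[i] = expf(dst[i] - max);
        sum += dst[i];
    }
    const float inv = sum > 0.0f ? 1.0f / sum : 1.0f;
    for (int i = 0; i < n; i++) {
        dst[i] *= inv;
    }
}
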
|
|
|
|||
|
|
@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|||
}
|
||||
}
|
||||
|
||||
static void scale_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
float scale = 0.f;
|
||||
float bias = 0.f;
|
||||
memcpy(&scale, &op_params[0], sizeof(float));
|
||||
memcpy(&bias, &op_params[1], sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
|
||||
}
|
||||
|
||||
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
|
|
@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src,
|
|||
const float mean = sum / row_elems;
|
||||
const float scale = 1.0f / sqrtf(mean + epsilon);
|
||||
|
||||
hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
|
||||
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
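// Scalar model of the RMS-norm row update above (illustrative only): each
// element is divided by sqrt(mean(x^2) + epsilon), which the HVX path
// expresses as a single hvx_scale_f32() by the precomputed reciprocal.
static inline void rms_norm_row_ref(float * dst, const float * src, int n, float epsilon) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += src[i] * src[i];
    }
    const float scale = 1.0f / sqrtf(sum / n + epsilon);
    for (int i = 0; i < n; i++) {
        dst[i] = src[i] * scale;
    }
}
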
|
|
@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
|||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
|
|
@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "rmsnorm-f32";
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "scale-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
|
|
|
|||
|
|
@ -550,6 +550,8 @@ struct vk_device_struct {
|
|||
uint64_t max_memory_allocation_size;
|
||||
uint64_t max_buffer_size;
|
||||
uint64_t suballocation_block_size;
|
||||
uint64_t min_imported_host_pointer_alignment;
|
||||
bool external_memory_host {};
|
||||
bool fp16;
|
||||
bool bf16;
|
||||
bool pipeline_robustness;
|
||||
|
|
@ -2410,7 +2412,8 @@ static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDe
|
|||
return indices;
|
||||
}
|
||||
|
||||
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
|
||||
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
|
||||
void *import_ptr = nullptr) {
|
||||
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
|
||||
if (size > device->max_buffer_size) {
|
||||
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
|
||||
|
|
@ -2439,6 +2442,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
nullptr,
|
||||
};
|
||||
|
||||
vk::ExternalMemoryBufferCreateInfo external_memory_bci;
|
||||
if (import_ptr) {
|
||||
external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
|
||||
buffer_create_info.setPNext(&external_memory_bci);
|
||||
}
|
||||
|
||||
buf->buffer = device->device.createBuffer(buffer_create_info);
|
||||
|
||||
vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
|
||||
|
|
@ -2453,35 +2462,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
mem_flags_info.setPNext(&mem_priority_info);
|
||||
}
|
||||
|
||||
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
||||
const auto & req_flags = *it;
|
||||
|
||||
const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
|
||||
|
||||
if (memory_type_indices.empty()) {
|
||||
continue;
|
||||
if (import_ptr) {
|
||||
vk::MemoryHostPointerPropertiesEXT host_pointer_props;
|
||||
try {
|
||||
host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
|
||||
} catch (vk::SystemError& e) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
|
||||
device->device.destroyBuffer(buf->buffer);
|
||||
return {};
|
||||
}
|
||||
buf->memory_property_flags = req_flags;
|
||||
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
||||
|
||||
bool done = false;
|
||||
uint32_t memory_type_idx;
|
||||
vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
|
||||
for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
|
||||
if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
|
||||
continue;
|
||||
}
|
||||
if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
|
||||
try {
|
||||
buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
|
||||
done = true;
|
||||
vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
|
||||
// check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
|
||||
if ((memory_type.propertyFlags & property_flags) == property_flags) {
|
||||
property_flags = memory_type.propertyFlags;
|
||||
break;
|
||||
} catch (const vk::SystemError& e) {
|
||||
// loop and retry
|
||||
// during last attempt throw the exception
|
||||
if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
|
||||
device->device.destroyBuffer(buf->buffer);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (memory_type_idx == 32) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
|
||||
device->device.destroyBuffer(buf->buffer);
|
||||
return {};
|
||||
}
|
||||
|
||||
if (done) {
|
||||
break;
|
||||
buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
|
||||
try {
|
||||
vk::ImportMemoryHostPointerInfoEXT import_info;
|
||||
import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
|
||||
import_info.pHostPointer = import_ptr;
|
||||
import_info.setPNext(&mem_flags_info);
|
||||
buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
|
||||
} catch (const vk::SystemError& e) {
|
||||
}
|
||||
} else {
|
||||
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
||||
const auto & req_flags = *it;
|
||||
|
||||
const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
|
||||
|
||||
if (memory_type_indices.empty()) {
|
||||
continue;
|
||||
}
|
||||
buf->memory_property_flags = req_flags;
|
||||
|
||||
bool done = false;
|
||||
|
||||
for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
|
||||
try {
|
||||
buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
|
||||
done = true;
|
||||
break;
|
||||
} catch (const vk::SystemError& e) {
|
||||
// loop and retry
|
||||
// during last attempt throw the exception
|
||||
if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
|
||||
device->device.destroyBuffer(buf->buffer);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2492,8 +2546,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
|
||||
buf->ptr = nullptr;
|
||||
|
||||
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
|
||||
buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
|
||||
if (import_ptr) {
|
||||
buf->ptr = import_ptr;
|
||||
} else {
|
||||
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
|
||||
buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
|
||||
|
|
@ -4447,6 +4505,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
|
||||
getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
|
||||
device->memory_priority = true;
|
||||
} else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
|
||||
device->external_memory_host = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4461,6 +4521,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
vk::PhysicalDeviceVulkan12Properties vk12_props;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
|
||||
vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;
|
||||
|
||||
props2.pNext = &props3;
|
||||
props3.pNext = &subgroup_props;
|
||||
|
|
@ -4500,11 +4561,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
||||
}
|
||||
|
||||
if (device->external_memory_host) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
|
||||
last_struct = (VkBaseOutStructure *)&external_memory_host_props;
|
||||
}
|
||||
|
||||
device->physical_device.getProperties2(&props2);
|
||||
device->properties = props2.properties;
|
||||
device->vendor_id = device->properties.vendorID;
|
||||
device->driver_id = driver_props.driverID;
|
||||
|
||||
if (device->driver_id == vk::DriverId::eMoltenvk) {
|
||||
// Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622
|
||||
// is available in the Vulkan SDK.
|
||||
device->external_memory_host = false;
|
||||
}
|
||||
|
||||
// Implementing the async backend interfaces seems broken on older Intel HW,
|
||||
// see https://github.com/ggml-org/llama.cpp/issues/17302.
|
||||
device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
|
||||
|
|
@ -4586,6 +4658,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
|
||||
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
|
||||
|
||||
device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;
|
||||
|
||||
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
|
||||
|
||||
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
||||
|
|
@ -4717,6 +4791,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
device_extensions.push_back("VK_KHR_pipeline_executable_properties");
|
||||
}
|
||||
|
||||
if (device->external_memory_host) {
|
||||
device_extensions.push_back("VK_EXT_external_memory_host");
|
||||
}
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
|
||||
|
||||
device->pipeline_executable_properties_support = pipeline_executable_properties_support;
|
||||
|
|
@ -14773,6 +14851,51 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
|
|||
VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
|
||||
}
|
||||
|
||||
static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
|
||||
if (!device->external_memory_host) {
|
||||
return {};
|
||||
}
|
||||
|
||||
uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
|
||||
if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
|
||||
return {};
|
||||
}
|
||||
if (size & (device->min_imported_host_pointer_alignment - 1)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
|
||||
|
||||
vk_buffer buf {};
|
||||
try {
|
||||
buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
|
||||
} catch (vk::SystemError& e) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
|
||||
}
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
|
||||
GGML_UNUSED(max_tensor_size);
|
||||
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
|
||||
vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);
|
||||
|
||||
if (!buf) {
|
||||
return {};
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);
|
||||
|
||||
ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
|
||||
/* .get_name = */ ggml_backend_vk_device_get_name,
|
||||
/* .get_description = */ ggml_backend_vk_device_get_description,
|
||||
|
|
@ -14782,7 +14905,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
|
|||
/* .init_backend = */ ggml_backend_vk_device_init,
|
||||
/* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
|
||||
/* .buffer_from_host_ptr = */ NULL,
|
||||
/* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
|
||||
/* .supports_op = */ ggml_backend_vk_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_vk_device_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_vk_device_offload_op,
|
||||
|
|
|
|||
|
|
@ -16,8 +16,14 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf"
device="HTP0"
[ "$D" != "" ] && device="$D"

verbose=""
[ "$V" != "" ] && verbose="$V"
verbose=
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"

experimental=
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"

profile=
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v"

opmask=
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"

@ -34,7 +40,7 @@ adb $adbserial shell " \
    cd $basedir; \
    LD_LIBRARY_PATH=$basedir/$branch/lib \
    ADSP_LIBRARY_PATH=$basedir/$branch/lib \
    $ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
    $ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
        --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
        --batch-size 128 -ngl 99 $@ \
        --batch-size 128 -ngl 99 $cli_opts $@ \
"
@ -359,6 +359,11 @@ static void llama_params_fit_impl(
|
|||
|
||||
// for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE:
|
||||
layer_fraction_t overflow_type = LAYER_FRACTION_MOE;
|
||||
|
||||
uint32_t n_full() const {
|
||||
assert(n_layer >= n_part);
|
||||
return n_layer - n_part;
|
||||
}
|
||||
};
|
||||
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
|
|
@ -382,7 +387,7 @@ static void llama_params_fit_impl(
|
|||
|
||||
size_t itbo = 0;
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part;
|
||||
il0 += ngl_per_device[id].n_full();
|
||||
for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) {
|
||||
if (itbo + 1 >= ntbo) {
|
||||
tensor_buft_overrides[itbo].pattern = nullptr;
|
||||
|
|
@ -393,7 +398,7 @@ static void llama_params_fit_impl(
|
|||
+ std::to_string(ntbo) + " is insufficient for model");
|
||||
}
|
||||
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
|
||||
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
|
||||
tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type();
|
||||
itbo++;
|
||||
}
|
||||
il0 += ngl_per_device[id].n_part;
|
||||
|
|
@ -468,20 +473,14 @@ static void llama_params_fit_impl(
|
|||
LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
|
||||
}
|
||||
|
||||
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the partial layers of a device overflow to:
|
||||
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the first partial layer of a device overflows to:
|
||||
overflow_bufts.reserve(nd);
|
||||
for (size_t id = 0; id < nd - 1; ++id) {
|
||||
overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1]));
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
|
||||
}
|
||||
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
|
||||
|
||||
std::vector<ngl_t> ngl_per_device(nd);
|
||||
std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts);
|
||||
if (hp_nex > 0) {
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE;
|
||||
}
|
||||
}
|
||||
|
||||
// optimize the number of layers per device using the method of false position:
|
||||
// - ngl_per_device has 0 layers for each device, lower bound
|
||||
|
|
@ -512,9 +511,6 @@ static void llama_params_fit_impl(
|
|||
if (mem_high[id] > targets[id]) {
|
||||
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
|
||||
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
||||
if (hp_nex > 0 && size_t(id) == nd - 1) {
|
||||
delta--;
|
||||
}
|
||||
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
|
||||
while (delta > 1) {
|
||||
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
|
||||
|
|
@ -524,7 +520,8 @@ static void llama_params_fit_impl(
|
|||
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
|
||||
ngl_per_device_test[id].n_layer += step_size;
|
||||
if (hp_nex) {
|
||||
ngl_per_device_test[id].n_part += step_size;
|
||||
ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ?
|
||||
step_size - 1 : step_size; // the first layer is the output layer which must always be full
|
||||
}
|
||||
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||
|
||||
|
|
@ -573,7 +570,7 @@ static void llama_params_fit_impl(
|
|||
assert(id_dense_start < nd);
|
||||
|
||||
LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__);
|
||||
for (size_t id = 0; id <= id_dense_start; id++) {
|
||||
for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) {
|
||||
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
||||
for (size_t jd = id_dense_start; jd < nd; jd++) {
|
||||
const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1;
|
||||
|
|
@ -585,12 +582,8 @@ static void llama_params_fit_impl(
|
|||
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
|
||||
|
||||
if (mem_high[id] > targets[id]) {
|
||||
assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part);
|
||||
assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part);
|
||||
assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||
>= ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||
uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||
- (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||
assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
|
||||
uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
|
||||
while (delta > 1) {
|
||||
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
|
||||
step_size = std::max(step_size, uint32_t(1));
|
||||
|
|
@ -606,7 +599,7 @@ static void llama_params_fit_impl(
|
|||
ngl_per_device_test[id].n_layer += n_convert_jd;
|
||||
n_converted_test += n_convert_jd;
|
||||
|
||||
if (ngl_per_device_test[id_dense_start_test].n_layer > 0) {
|
||||
if (ngl_per_device_test[id_dense_start_test].n_part > 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -625,8 +618,8 @@ static void llama_params_fit_impl(
|
|||
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n",
|
||||
__func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high);
|
||||
}
|
||||
delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||
- (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||
assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
|
||||
delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
|
||||
}
|
||||
} else {
|
||||
ngl_per_device = ngl_per_device_high;
|
||||
|
|
@ -644,14 +637,19 @@ static void llama_params_fit_impl(
|
|||
ngl_per_device_test[id_dense_start_test].n_part--;
|
||||
ngl_per_device_test[id].n_layer++;
|
||||
ngl_per_device_test[id].n_part++;
|
||||
if (ngl_per_device_test[id_dense_start_test].n_layer == 0) {
|
||||
if (ngl_per_device_test[id_dense_start_test].n_part == 0) {
|
||||
id_dense_start_test++;
|
||||
}
|
||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
|
||||
std::vector<ggml_backend_buffer_type_t> overflow_bufts_test = overflow_bufts;
|
||||
if (id < nd - 1) {
|
||||
overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]);
|
||||
}
|
||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
|
||||
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
|
||||
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
overflow_bufts = overflow_bufts_test;
|
||||
mem = mem_test;
|
||||
id_dense_start = id_dense_start_test;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n",
|
||||
|
|
@ -659,9 +657,10 @@ static void llama_params_fit_impl(
|
|||
|
||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
|
||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
|
||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
|
||||
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
overflow_bufts = overflow_bufts_test;
|
||||
mem = mem_test;
|
||||
id_dense_start = id_dense_start_test;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n",
|
||||
|
|
@ -670,9 +669,10 @@ static void llama_params_fit_impl(
|
|||
} else {
|
||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
|
||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
|
||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
|
||||
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
overflow_bufts = overflow_bufts_test;
|
||||
mem = mem_test;
|
||||
id_dense_start = id_dense_start_test;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n",
|
||||
|
|
@ -687,6 +687,14 @@ static void llama_params_fit_impl(
|
|||
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
|
||||
}
|
||||
|
||||
// print info for devices that were not changed during the conversion from dense only to full layers:
|
||||
for (size_t id = id_dense_start + 1; id < nd; id++) {
|
||||
const int64_t projected_margin = dmds_full[id].free - mem[id];
|
||||
LLAMA_LOG_INFO(
|
||||
"%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n",
|
||||
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
|
||||
}
|
||||
|
||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -127,6 +127,15 @@ int main(void) {
    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE));
    assert(params.speculative.n_max == 123);

    // multi-value args (CSV)
    argv = {"binary_name", "--lora", "file1.gguf,\"file2,2.gguf\",\"file3\"\"3\"\".gguf\",file4\".gguf"};
    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
    assert(params.lora_adapters.size() == 4);
    assert(params.lora_adapters[0].path == "file1.gguf");
    assert(params.lora_adapters[1].path == "file2,2.gguf");
    assert(params.lora_adapters[2].path == "file3\"3\".gguf");
    assert(params.lora_adapters[3].path == "file4\".gguf");

    // skip this part on windows, because setenv is not supported
#ifdef _WIN32
    printf("test-arg-parser: skip on windows build\n");
@ -9,207 +9,250 @@
|
|||
#include <fstream>
|
||||
#include <algorithm>
|
||||
|
||||
// most of the code here is copied from whisper.cpp
|
||||
// some of the code here is copied from whisper.cpp
|
||||
|
||||
constexpr bool DEBUG = false;
|
||||
|
||||
struct mtmd_audio_mel_filters {
|
||||
int32_t n_mel;
|
||||
int32_t n_fft;
|
||||
void mtmd_audio_cache::fill_sin_cos_table(int n) {
|
||||
sin_vals.resize(n);
|
||||
cos_vals.resize(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
double theta = (2 * M_PI * i) / n;
|
||||
sin_vals[i] = sinf(theta);
|
||||
cos_vals[i] = cosf(theta);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
|
||||
hann_window.resize(length);
|
||||
int offset = -1;
|
||||
if (periodic) {
|
||||
offset = 0;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
||||
}
|
||||
}
|
||||
|
||||
// note: this global cache is shared among all preprocessors
|
||||
// if we want to use multiple preprocessors at the same time,
|
||||
// we will need to enclose it in the preprocessor class in the future
|
||||
static struct mtmd_audio_global_cache {
|
||||
// precomputed sin/cos table for FFT
|
||||
std::vector<float> sin_vals;
|
||||
std::vector<float> cos_vals;
|
||||
|
||||
// hann window
|
||||
std::vector<float> hann_window;
|
||||
|
||||
// mel filter bank
|
||||
mtmd_audio_mel_filters filters;
|
||||
|
||||
void fill_sin_cos_table(int n) {
|
||||
sin_vals.resize(n);
|
||||
cos_vals.resize(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
double theta = (2 * M_PI * i) / n;
|
||||
sin_vals[i] = sinf(theta);
|
||||
cos_vals[i] = cosf(theta);
|
||||
}
|
||||
void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel,
|
||||
int n_fft,
|
||||
int sample_rate,
|
||||
float fmin,
|
||||
float fmax,
|
||||
bool slaney_area_norm,
|
||||
float scale) {
|
||||
GGML_ASSERT(n_mel > 0 && n_fft > 1);
|
||||
if (fmax <= 0.0f) {
|
||||
fmax = 0.5f * sample_rate;
|
||||
}
|
||||
|
||||
void fill_hann_window(int length, bool periodic) {
|
||||
hann_window.resize(length);
|
||||
int offset = -1;
|
||||
if (periodic) {
|
||||
offset = 0;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
||||
}
|
||||
// Slaney scale (matches librosa default)
|
||||
const double min_log_hz = 1000.0;
|
||||
const double lin_slope = 3 / 200.;
|
||||
const double min_log_mel = min_log_hz * lin_slope;
|
||||
const double log_step = log(6.4) / 27.0;
|
||||
auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
|
||||
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
|
||||
};
|
||||
auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
|
||||
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
|
||||
};
|
||||
|
||||
// infer N_fft from n_fft_bins
|
||||
const double bin_hz_step = double(sample_rate) / double(n_fft);
|
||||
|
||||
// mel grid: n_mel + 2 edges
|
||||
const double m_lo = hz_to_mel(fmin);
|
||||
const double m_hi = hz_to_mel(fmax);
|
||||
std::vector<double> mel_pts(n_mel + 2);
|
||||
for (int i = 0; i < n_mel + 2; ++i) {
|
||||
mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
|
||||
}
|
||||
|
||||
// Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
|
||||
// n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
|
||||
void fill_mel_filterbank_matrix(
|
||||
int n_mel,
|
||||
int n_fft,
|
||||
int sample_rate, // e.g. 16000
|
||||
float fmin = 0.0f, // e.g. 0.0
|
||||
float fmax = -1.0f, // e.g. sr/2; pass -1 for auto
|
||||
bool slaney_area_norm = true,
|
||||
float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code
|
||||
) {
|
||||
GGML_ASSERT(n_mel > 0 && n_fft > 1);
|
||||
if (fmax <= 0.0f) {
|
||||
fmax = 0.5f * sample_rate;
|
||||
}
|
||||
// convert to Hz
|
||||
std::vector<double> hz_pts(n_mel + 2);
|
||||
for (int i = 0; i < n_mel + 2; ++i) {
|
||||
hz_pts[i] = mel_to_hz(mel_pts[i]);
|
||||
}
|
||||
|
||||
// Slaney scale (matches librosa default)
|
||||
const double min_log_hz = 1000.0;
|
||||
const double lin_slope = 3 / 200.;
|
||||
const double min_log_mel = min_log_hz * lin_slope;
|
||||
const double log_step = log(6.4) / 27.0;
|
||||
auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
|
||||
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
|
||||
};
|
||||
auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
|
||||
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
|
||||
};
|
||||
const int n_fft_bins = n_fft / 2 + 1;
|
||||
|
||||
// infer N_fft from n_fft_bins
|
||||
const double bin_hz_step = double(sample_rate) / double(n_fft);
|
||||
// filterbank
|
||||
std::vector<float> out(n_mel * n_fft_bins, 0);
|
||||
for (int m = 0; m < n_mel; ++m) {
|
||||
const double f_left = hz_pts[m];
|
||||
const double f_center = hz_pts[m + 1];
|
||||
const double f_right = hz_pts[m + 2];
|
||||
|
||||
// mel grid: n_mel + 2 edges
|
||||
const double m_lo = hz_to_mel(fmin);
|
||||
const double m_hi = hz_to_mel(fmax);
|
||||
std::vector<double> mel_pts(n_mel + 2);
|
||||
for (int i = 0; i < n_mel + 2; ++i) {
|
||||
mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
|
||||
}
|
||||
const double denom_l = std::max(1e-30, f_center - f_left);
|
||||
const double denom_r = std::max(1e-30, f_right - f_center);
|
||||
const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
|
||||
|
||||
// convert to Hz
|
||||
std::vector<double> hz_pts(n_mel + 2);
|
||||
for (int i = 0; i < n_mel + 2; ++i) {
|
||||
hz_pts[i] = mel_to_hz(mel_pts[i]);
|
||||
}
|
||||
|
||||
const int n_fft_bins = n_fft / 2 + 1;
|
||||
|
||||
// filterbank
|
||||
std::vector<float> out(n_mel * n_fft_bins, 0);
|
||||
for (int m = 0; m < n_mel; ++m) {
|
||||
const double f_left = hz_pts[m];
|
||||
const double f_center = hz_pts[m + 1];
|
||||
const double f_right = hz_pts[m + 2];
|
||||
|
||||
const double denom_l = std::max(1e-30, f_center - f_left);
|
||||
const double denom_r = std::max(1e-30, f_right - f_center);
|
||||
const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
|
||||
|
||||
for (int k = 0; k < n_fft_bins; ++k) {
|
||||
const double f = k * bin_hz_step;
|
||||
double w = 0.0;
|
||||
if (f >= f_left && f <= f_center) {
|
||||
w = (f - f_left) / denom_l;
|
||||
} else if (f > f_center && f <= f_right) {
|
||||
w = (f_right - f) / denom_r;
|
||||
}
|
||||
out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
|
||||
for (int k = 0; k < n_fft_bins; ++k) {
|
||||
const double f = k * bin_hz_step;
|
||||
double w = 0.0;
|
||||
if (f >= f_left && f <= f_center) {
|
||||
w = (f - f_left) / denom_l;
|
||||
} else if (f > f_center && f <= f_right) {
|
||||
w = (f_right - f) / denom_r;
|
||||
}
|
||||
out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
|
||||
}
|
||||
}
|
||||
|
||||
filters.n_mel = n_mel;
|
||||
filters.n_fft = n_fft;
|
||||
filters.data = std::move(out);
|
||||
filters.n_mel = n_mel;
|
||||
filters.n_fft = n_fft;
|
||||
filters.data = std::move(out);
|
||||
|
||||
if (DEBUG) { // debug
|
||||
for (size_t i = 0; i < filters.data.size(); ++i) {
|
||||
if (filters.data[i] != 0.0f) {
|
||||
printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
|
||||
}
|
||||
if (DEBUG) { // debug
|
||||
for (size_t i = 0; i < filters.data.size(); ++i) {
|
||||
if (filters.data[i] != 0.0f) {
|
||||
printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
} g_cache;
|
||||
}
|
||||
|
||||
// naive Discrete Fourier Transform
|
||||
// input is real-valued
|
||||
// output is complex-valued
|
||||
static void dft(const float * in, int N, float * out) {
|
||||
const int n_sin_cos_vals = g_cache.sin_vals.size();
|
||||
const int sin_cos_step = n_sin_cos_vals / N;
|
||||
// Unified DFT implementation for both forward and inverse transforms
|
||||
// Template parameters:
|
||||
// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling
|
||||
// true = IDFT with exp(+2πi·k·n/N), scales by 1/N
|
||||
// RealInput: true = input is real-valued (stride 1), avoids imaginary computations
|
||||
// false = input is complex-valued (interleaved real/imag, stride 2)
|
||||
template <bool Inverse, bool RealInput>
|
||||
static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) {
|
||||
const int n_sin_cos_vals = cache.sin_vals.size();
|
||||
const int sin_cos_step = n_sin_cos_vals / N;
|
||||
|
||||
constexpr float sign = Inverse ? 1.0f : -1.0f;
|
||||
const float scale = Inverse ? (1.0f / N) : 1.0f;
|
||||
|
||||
for (int k = 0; k < N; k++) {
|
||||
float re = 0;
|
||||
float im = 0;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N
|
||||
re += in[n] * g_cache.cos_vals[idx]; // cos(t)
|
||||
im -= in[n] * g_cache.sin_vals[idx]; // sin(t)
|
||||
int idx = (k * n * sin_cos_step) % n_sin_cos_vals;
|
||||
float cos_val = cache.cos_vals[idx];
|
||||
float sin_val = cache.sin_vals[idx];
|
||||
|
||||
if constexpr (RealInput) {
|
||||
// Real input: in_im = 0, simplifies to:
|
||||
// re += in_re * cos_val
|
||||
// im += sign * in_re * sin_val
|
||||
float in_re = in[n];
|
||||
re += in_re * cos_val;
|
||||
im += sign * in_re * sin_val;
|
||||
} else {
|
||||
float in_re = in[n * 2 + 0];
|
||||
float in_im = in[n * 2 + 1];
|
||||
// (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i
|
||||
re += in_re * cos_val - sign * in_im * sin_val;
|
||||
im += sign * in_re * sin_val + in_im * cos_val;
|
||||
}
|
||||
}
|
||||
|
||||
out[k*2 + 0] = re;
|
||||
out[k*2 + 1] = im;
|
||||
out[k * 2 + 0] = re * scale;
|
||||
out[k * 2 + 1] = im * scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Cooley-Tukey FFT
|
||||
// poor man's implementation - use something better
|
||||
// input is real-valued
|
||||
// output is complex-valued
|
||||
static void fft(float * in, int N, float * out) {
|
||||
const int n_sin_cos_vals = g_cache.sin_vals.size();
|
||||
// Cooley-Tukey FFT/IFFT unified implementation
|
||||
// Template parameters:
|
||||
// Inverse: false = FFT with exp(-2πi·k/N), no scaling
|
||||
// true = IFFT with exp(+2πi·k/N), scales by 0.5 at each level
|
||||
// RealInput: true = input is real-valued (stride 1)
|
||||
// false = input is complex-valued (interleaved real/imag, stride 2)
|
||||
template <bool Inverse, bool RealInput>
|
||||
static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
|
||||
const int n_sin_cos_vals = cache.sin_vals.size();
|
||||
|
||||
if (N == 1) {
|
||||
out[0] = in[0];
|
||||
out[1] = 0;
|
||||
if constexpr (RealInput) {
|
||||
out[1] = 0.0f;
|
||||
} else {
|
||||
out[1] = in[1];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const int half_N = N / 2;
|
||||
if (N - half_N*2 == 1) {
|
||||
dft(in, N, out);
|
||||
if (N - half_N * 2 == 1) {
|
||||
// Odd N: fall back to DFT
|
||||
dft_impl<Inverse, RealInput>(cache, in, N, out);
|
||||
return;
|
||||
}
|
||||
|
||||
float* even = in + N;
|
||||
for (int i = 0; i < half_N; ++i) {
|
||||
even[i]= in[2*i];
|
||||
}
|
||||
float* even_fft = out + 2 * N;
|
||||
fft(even, half_N, even_fft);
|
||||
// Split into even and odd
|
||||
if constexpr (RealInput) {
|
||||
// Real input: stride is 1, copy only real values
|
||||
float * even = in + N;
|
||||
for (int i = 0; i < half_N; ++i) {
|
||||
even[i] = in[2 * i];
|
||||
}
|
||||
float * even_fft = out + 2 * N;
|
||||
fft_impl<Inverse, true>(cache, even, half_N, even_fft);
|
||||
|
||||
float* odd = even;
|
||||
for (int i = 0; i < half_N; ++i) {
|
||||
odd[i] = in[2*i + 1];
|
||||
float * odd = even;
|
||||
for (int i = 0; i < half_N; ++i) {
|
||||
odd[i] = in[2 * i + 1];
|
||||
}
|
||||
float * odd_fft = even_fft + N;
|
||||
fft_impl<Inverse, true>(cache, odd, half_N, odd_fft);
|
||||
} else {
|
||||
// Complex input: stride is 2, copy complex pairs
|
||||
float * even = in + N * 2;
|
||||
for (int i = 0; i < half_N; ++i) {
|
||||
even[i * 2 + 0] = in[2 * i * 2 + 0];
|
||||
even[i * 2 + 1] = in[2 * i * 2 + 1];
|
||||
}
|
||||
float * even_fft = out + 2 * N;
|
||||
fft_impl<Inverse, false>(cache, even, half_N, even_fft);
|
||||
|
||||
float * odd = even;
|
||||
for (int i = 0; i < half_N; ++i) {
|
||||
odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0];
|
||||
odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1];
|
||||
}
|
||||
float * odd_fft = even_fft + N;
|
||||
fft_impl<Inverse, false>(cache, odd, half_N, odd_fft);
|
||||
}
|
||||
float* odd_fft = even_fft + N;
|
||||
fft(odd, half_N, odd_fft);
|
||||
|
||||
float * even_fft = out + 2 * N;
|
||||
float * odd_fft = even_fft + N;
|
||||
|
||||
const int sin_cos_step = n_sin_cos_vals / N;
|
||||
|
||||
constexpr float sign = Inverse ? 1.0f : -1.0f;
|
||||
constexpr float scale = Inverse ? 0.5f : 1.0f;
|
||||
|
||||
for (int k = 0; k < half_N; k++) {
|
||||
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
||||
float re = g_cache.cos_vals[idx]; // cos(t)
|
||||
float im = -g_cache.sin_vals[idx]; // sin(t)
|
||||
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
||||
float re = cache.cos_vals[idx];
|
||||
float im = sign * cache.sin_vals[idx];
|
||||
|
||||
float re_odd = odd_fft[2*k + 0];
|
||||
float im_odd = odd_fft[2*k + 1];
|
||||
float re_odd = odd_fft[2 * k + 0];
|
||||
float im_odd = odd_fft[2 * k + 1];
|
||||
|
||||
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
|
||||
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
|
||||
out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd);
|
||||
out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd);
|
||||
|
||||
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
|
||||
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
|
||||
out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd);
|
||||
out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd);
|
||||
}
|
||||
}
|
||||
|
||||
// Forward FFT for real input (used by mel spectrogram)
|
||||
static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
|
||||
fft_impl<false, true>(cache, in, N, out);
|
||||
}
|
||||
|
||||
// Inverse FFT for complex input
|
||||
static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
|
||||
fft_impl<true, false>(cache, in, N, out);
|
||||
}
|
||||
|
||||
struct filter_params {
|
||||
int32_t n_mel;
|
||||
int32_t n_fft_bins;
|
||||
|
|
@ -222,20 +265,27 @@ struct filter_params {
|
|||
bool norm_per_feature = false;
|
||||
};
|
||||
|
||||
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
||||
int n_samples, int frame_size, int frame_step, int n_threads,
|
||||
const filter_params & params, mtmd_audio_mel & out) {
|
||||
static void log_mel_spectrogram_worker_thread(int ith,
|
||||
const float * hann,
|
||||
const std::vector<float> & samples,
|
||||
int n_samples,
|
||||
int frame_size,
|
||||
int frame_step,
|
||||
int n_threads,
|
||||
const filter_params & params,
|
||||
const mtmd_audio_cache & cache,
|
||||
mtmd_audio_mel & out) {
|
||||
std::vector<float> fft_in(frame_size * 2, 0.0);
|
||||
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
|
||||
|
||||
int n_fft_bins = params.n_fft_bins;
|
||||
int i = ith;
|
||||
|
||||
const auto & filters = g_cache.filters;
|
||||
const auto & filters = cache.filters;
|
||||
|
||||
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
|
||||
GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2));
|
||||
GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size());
|
||||
GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size());
|
||||
// calculate FFT only when fft_in are not all zero
|
||||
for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) {
|
||||
const int offset = i * frame_step;
|
||||
|
|
@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
|
|||
}
|
||||
|
||||
// FFT
|
||||
fft(fft_in.data(), frame_size, fft_out.data());
|
||||
fft(cache, fft_in.data(), frame_size, fft_out.data());
|
||||
|
||||
// Calculate modulus^2 of complex numbers
|
||||
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
||||
|
|
@ -298,6 +348,7 @@ static bool log_mel_spectrogram(
|
|||
const int n_samples_in,
|
||||
const int n_threads,
|
||||
const filter_params & params,
|
||||
const mtmd_audio_cache & cache,
|
||||
mtmd_audio_mel & out) {
|
||||
//const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
|
|
@ -305,9 +356,9 @@ static bool log_mel_spectrogram(
|
|||
int n_samples = n_samples_in;
|
||||
|
||||
// Hann window
|
||||
const float * hann = g_cache.hann_window.data();
|
||||
const int frame_size = (params.n_fft_bins - 1) * 2;
|
||||
const int frame_step = params.hop_length;
|
||||
const float * hann = cache.hann_window.data();
|
||||
const int frame_size = (params.n_fft_bins - 1) * 2;
|
||||
const int frame_step = params.hop_length;
|
||||
|
||||
// Padding
|
||||
std::vector<float> samples_padded;
|
||||
|
|
@ -335,9 +386,9 @@ static bool log_mel_spectrogram(
|
|||
|
||||
// preemphasis
|
||||
if (params.preemph) {
|
||||
const int pad_amount = frame_size / 2;
|
||||
const int pad_amount = frame_size / 2;
|
||||
const float preemph = 0.97f;
|
||||
float prev = samples_padded[pad_amount];
|
||||
float prev = samples_padded[pad_amount];
|
||||
for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) {
|
||||
float cur = samples_padded[i];
|
||||
samples_padded[i] = cur - preemph * prev;
|
||||
|
|
@ -372,14 +423,14 @@ static bool log_mel_spectrogram(
|
|||
{
|
||||
std::vector<std::thread> workers(n_threads - 1);
|
||||
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||
workers[iw] = std::thread(
|
||||
log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
|
||||
n_samples, frame_size, frame_step, n_threads,
|
||||
std::cref(params), std::ref(out));
|
||||
workers[iw] =
|
||||
std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples,
|
||||
frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out));
|
||||
}
|
||||
|
||||
// main thread
|
||||
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out);
|
||||
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params,
|
||||
cache, out);
|
||||
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||
workers[iw].join();
|
||||
}
|
||||
|
|
@ -404,7 +455,7 @@ static bool log_mel_spectrogram(
|
|||
|
||||
for (int j = 0; j < effective_n_len; ++j) {
|
||||
auto &value = out.data[i * out.n_len + j];
|
||||
value = (value - mean) / mstd;
|
||||
value = (value - mean) / mstd;
|
||||
}
|
||||
|
||||
// pad the rest with zeros
|
||||
|
|
@ -450,18 +501,14 @@ static bool log_mel_spectrogram(
|
|||
//
|
||||
|
||||
void mtmd_audio_preprocessor_whisper::initialize() {
|
||||
g_cache.fill_sin_cos_table(hparams.audio_n_fft);
|
||||
g_cache.fill_hann_window(hparams.audio_window_len, true);
|
||||
g_cache.fill_mel_filterbank_matrix(
|
||||
hparams.n_mel_bins,
|
||||
hparams.audio_n_fft,
|
||||
hparams.audio_sample_rate);
|
||||
cache.fill_sin_cos_table(hparams.audio_n_fft);
|
||||
cache.fill_hann_window(hparams.audio_window_len, true);
|
||||
cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
|
||||
}
|
||||
|
||||
bool mtmd_audio_preprocessor_whisper::preprocess(
|
||||
const float * samples,
|
||||
size_t n_samples,
|
||||
std::vector<mtmd_audio_mel> & output) {
|
||||
bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples,
|
||||
size_t n_samples,
|
||||
std::vector<mtmd_audio_mel> & output) {
|
||||
if (n_samples == 0) {
|
||||
// empty audio
|
||||
return false;
|
||||
|
|
@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
|
|||
// if input is too short, pad with zeros
|
||||
// this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram
|
||||
// TODO: maybe handle this better
|
||||
size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
|
||||
size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
|
||||
if (n_samples < min_samples) {
|
||||
smpl.resize(min_samples, 0.0f);
|
||||
std::memcpy(smpl.data(), samples, n_samples * sizeof(float));
|
||||
|
|
@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
|
|||
params.hop_length = hparams.audio_hop_len;
|
||||
params.sample_rate = hparams.audio_sample_rate;
|
||||
params.center_padding = false;
|
||||
params.preemph = 0.0f; // disabled
|
||||
params.preemph = 0.0f; // disabled
|
||||
params.use_natural_log = false;
|
||||
params.norm_per_feature = false;
|
||||
|
||||
// make sure the global cache is initialized
|
||||
GGML_ASSERT(!g_cache.sin_vals.empty());
|
||||
GGML_ASSERT(!g_cache.cos_vals.empty());
|
||||
GGML_ASSERT(!g_cache.filters.data.empty());
|
||||
// make sure the cache is initialized
|
||||
GGML_ASSERT(!cache.sin_vals.empty());
|
||||
GGML_ASSERT(!cache.cos_vals.empty());
|
||||
GGML_ASSERT(!cache.filters.data.empty());
|
||||
|
||||
mtmd_audio_mel out_full;
|
||||
bool ok = log_mel_spectrogram(
|
||||
samples,
|
||||
n_samples,
|
||||
4, // n_threads
|
||||
params,
|
||||
out_full);
|
||||
bool ok = log_mel_spectrogram(samples, n_samples,
|
||||
4, // n_threads
|
||||
params, cache, out_full);
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
        printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len);
    }
    const size_t frames_per_chunk = 3000;
    GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
    for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
        int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
        if ((size_t)n_len < frames_per_chunk) {
            break; // last incomplete chunk will always be a padded chunk, safe to ignore
    GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk);
    for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) {
        int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off);
        if ((size_t) n_len < frames_per_chunk) {
            break; // last incomplete chunk will always be a padded chunk, safe to ignore
        }

        mtmd_audio_mel out_chunk;
        out_chunk.n_len = n_len;
        out_chunk.n_mel = out_full.n_mel;
        out_chunk.n_len_org = out_full.n_mel; // unused
        out_chunk.n_len_org = out_full.n_mel; // unused
        out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);

        for (int i = 0; i < out_full.n_mel; i++) {
            auto src = out_full.data.begin() + i*out_full.n_len + off;
            auto src = out_full.data.begin() + i * out_full.n_len + off;
            out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
        }
@ -541,18 +585,14 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
|
|||
//
|
||||
|
||||
void mtmd_audio_preprocessor_conformer::initialize() {
|
||||
g_cache.fill_sin_cos_table(hparams.audio_n_fft);
|
||||
g_cache.fill_hann_window(hparams.audio_window_len, true);
|
||||
g_cache.fill_mel_filterbank_matrix(
|
||||
hparams.n_mel_bins,
|
||||
hparams.audio_n_fft,
|
||||
hparams.audio_sample_rate);
|
||||
cache.fill_sin_cos_table(hparams.audio_n_fft);
|
||||
cache.fill_hann_window(hparams.audio_window_len, true);
|
||||
cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
|
||||
}
|
||||
|
||||
bool mtmd_audio_preprocessor_conformer::preprocess(
|
||||
const float * samples,
|
||||
size_t n_samples,
|
||||
std::vector<mtmd_audio_mel> & output) {
|
||||
bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples,
|
||||
size_t n_samples,
|
||||
std::vector<mtmd_audio_mel> & output) {
|
||||
// empty audio
|
||||
if (n_samples == 0) {
|
||||
return false;
|
||||
|
|
@ -569,18 +609,15 @@ bool mtmd_audio_preprocessor_conformer::preprocess(
|
|||
params.use_natural_log = true;
|
||||
params.norm_per_feature = true;
|
||||
|
||||
// make sure the global cache is initialized
|
||||
GGML_ASSERT(!g_cache.sin_vals.empty());
|
||||
GGML_ASSERT(!g_cache.cos_vals.empty());
|
||||
GGML_ASSERT(!g_cache.filters.data.empty());
|
||||
// make sure the cache is initialized
|
||||
GGML_ASSERT(!cache.sin_vals.empty());
|
||||
GGML_ASSERT(!cache.cos_vals.empty());
|
||||
GGML_ASSERT(!cache.filters.data.empty());
|
||||
|
||||
mtmd_audio_mel out_full;
|
||||
bool ok = log_mel_spectrogram(
|
||||
samples,
|
||||
n_samples,
|
||||
4, // n_threads
|
||||
params,
|
||||
out_full);
|
||||
bool ok = log_mel_spectrogram(samples, n_samples,
|
||||
4, // n_threads
|
||||
params, cache, out_full);
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -588,3 +625,106 @@ bool mtmd_audio_preprocessor_conformer::preprocess(
|
|||
output.push_back(std::move(out_full));
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// mtmd_audio_streaming_istft implementation
|
||||
//
|
||||
|
||||
mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) :
|
||||
n_fft(n_fft),
|
||||
hop_length(hop_length),
|
||||
n_fft_bins(n_fft / 2 + 1),
|
||||
overlap_buffer(n_fft, 0.0f),
|
||||
window_sum_buffer(n_fft, 0.0f),
|
||||
padding_to_remove((n_fft - hop_length) / 2),
|
||||
ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT
|
||||
ifft_out(n_fft * 2 * 4, 0.0f) {
|
||||
cache.fill_sin_cos_table(n_fft);
|
||||
cache.fill_hann_window(n_fft, true);
|
||||
}
|
||||
|
||||
void mtmd_audio_streaming_istft::reset() {
|
||||
std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f);
|
||||
std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f);
|
||||
padding_to_remove = (n_fft - hop_length) / 2;
|
||||
}
|
||||
|
||||
std::vector<float> mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) {
|
||||
std::vector<float> output(hop_length);
|
||||
|
||||
// copy frequencies
|
||||
for (int j = 0; j < n_fft_bins; j++) {
|
||||
ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0];
|
||||
ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1];
|
||||
}
|
||||
|
||||
// mirror negative frequencies
|
||||
for (int j = 1; j < n_fft_bins - 1; j++) {
|
||||
int mirror_idx = n_fft - j;
|
||||
ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0];
|
||||
ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate
|
||||
}
|
||||
|
||||
ifft(cache, ifft_in.data(), n_fft, ifft_out.data());
|
||||
|
||||
// update window sum and overlap buffer
|
||||
for (int j = 0; j < n_fft; j++) {
|
||||
window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j];
|
||||
overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j];
|
||||
}
|
||||
|
||||
// extract hop_length samples with normalization
|
||||
for (int i = 0; i < hop_length; i++) {
|
||||
if (window_sum_buffer[i] > 1e-8f) {
|
||||
output[i] = overlap_buffer[i] / window_sum_buffer[i];
|
||||
} else {
|
||||
output[i] = overlap_buffer[i];
|
||||
}
|
||||
}
|
||||
|
||||
// shift buffers left by hop_length
|
||||
std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin());
|
||||
std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f);
|
||||
|
||||
std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin());
|
||||
std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f);
|
||||
|
||||
// Remove padding if needed
|
||||
int to_remove = std::min(padding_to_remove, (int) output.size());
|
||||
padding_to_remove -= to_remove;
|
||||
output.erase(output.begin(), output.begin() + to_remove);
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<float> mtmd_audio_streaming_istft::flush() {
|
||||
std::vector<float> output;
|
||||
|
||||
// Extract remaining samples from overlap buffer
|
||||
// Continue until we've extracted all meaningful samples
|
||||
int remaining = n_fft - hop_length;
|
||||
while (remaining > 0) {
|
||||
int chunk_size = std::min(remaining, hop_length);
|
||||
|
||||
for (int i = 0; i < chunk_size; i++) {
|
||||
float sample;
|
||||
if (window_sum_buffer[i] > 1e-8f) {
|
||||
sample = overlap_buffer[i] / window_sum_buffer[i];
|
||||
} else {
|
||||
sample = overlap_buffer[i];
|
||||
}
|
||||
output.push_back(sample);
|
||||
}
|
||||
|
||||
// Shift buffers
|
||||
std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin());
|
||||
std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f);
|
||||
|
||||
std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin());
|
||||
std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f);
|
||||
|
||||
remaining -= chunk_size;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,38 @@ struct mtmd_audio_mel {
|
|||
std::vector<float> data;
|
||||
};
|
||||
|
||||
struct mtmd_audio_mel_filters {
|
||||
int32_t n_mel;
|
||||
int32_t n_fft;
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
// cache for audio processing, each processor instance owns its own cache
|
||||
struct mtmd_audio_cache {
|
||||
std::vector<float> sin_vals;
|
||||
std::vector<float> cos_vals;
|
||||
|
||||
std::vector<float> hann_window;
|
||||
|
||||
mtmd_audio_mel_filters filters;
|
||||
|
||||
void fill_sin_cos_table(int n);
|
||||
|
||||
void fill_hann_window(int length, bool periodic);
|
||||
|
||||
// Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
|
||||
// n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
|
||||
void fill_mel_filterbank_matrix(int n_mel,
|
||||
int n_fft,
|
||||
int sample_rate, // e.g. 16000
|
||||
float fmin = 0.0f, // e.g. 0.0
|
||||
float fmax = -1.0f, // e.g. sr/2; pass -1 for auto
|
||||
bool slaney_area_norm = true,
|
||||
float scale = 1.0f // optional extra scaling
|
||||
);
|
||||
};
|
||||
|
||||
struct mtmd_audio_preprocessor {
|
||||
const clip_hparams & hparams;
|
||||
|
||||
|
|
@@ -31,10 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor {
    mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
    void initialize() override;
    bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;

private:
    mtmd_audio_cache cache;
};

struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor {
    mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
    void initialize() override;
    bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;

private:
    mtmd_audio_cache cache;
};

//
// streaming ISTFT - converts spectrogram frames back to audio one frame at a time
//
struct mtmd_audio_streaming_istft {
    mtmd_audio_streaming_istft(int n_fft, int hop_length);

    // reset streaming state
    void reset();

    // process a single STFT frame (streaming)
    // frame_spectrum: [n_fft_bins x 2] interleaved real/imag
    // returns: up to hop_length samples
    std::vector<float> process_frame(const float * frame_spectrum);

    // flush remaining samples at end of stream
    std::vector<float> flush();

private:
    int n_fft;
    int hop_length;
    int n_fft_bins;

    // Own cache for output processing
    mtmd_audio_cache cache;

    // Streaming state
    std::vector<float> overlap_buffer;
    std::vector<float> window_sum_buffer;
    int padding_to_remove;

    // Working buffers for IFFT
    std::vector<float> ifft_in;
    std::vector<float> ifft_out;
};
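A possible call pattern for the streaming ISTFT declared above (next_spectrum_frame() is a placeholder invented for this sketch): feed one complex STFT frame per call, collect up to hop_length samples each time, then flush the overlap buffer at the end of the stream.

mtmd_audio_streaming_istft istft(/*n_fft=*/512, /*hop_length=*/128);
std::vector<float> pcm;

while (const float * spec = next_spectrum_frame()) { // [n_fft_bins x 2] interleaved real/imag
    std::vector<float> chunk = istft.process_frame(spec);
    pcm.insert(pcm.end(), chunk.begin(), chunk.end());
}

std::vector<float> tail = istft.flush();             // drain remaining overlap-add samples
pcm.insert(pcm.end(), tail.begin(), tail.end());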
@@ -814,6 +814,15 @@ json server_task_result_cmpl_final::to_json_anthropic() {
        msg.content = content;
    }

    // thinking block comes first (Anthropic extended thinking format)
    if (!msg.reasoning_content.empty()) {
        content_blocks.push_back({
            {"type", "thinking"},
            {"thinking", msg.reasoning_content},
            {"signature", ""} // empty signature for local models (no cryptographic verification)
        });
    }

    if (!msg.content.empty()) {
        content_blocks.push_back({
            {"type", "text"},
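Given the pushes above, the non-streaming response should carry a content array of roughly this shape when the model produced reasoning (illustrative values; field names taken from the code, tool_use blocks only if the message has tool calls):

// "content": [
//   { "type": "thinking", "thinking": "<reasoning text>", "signature": "" },
//   { "type": "text",     "text": "<answer text>" },
//   ... tool_use blocks, if any ...
// ]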
@@ -862,20 +871,57 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
        stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
    }

    bool has_text = !oaicompat_msg.content.empty();
    bool has_thinking = !oaicompat_msg.reasoning_content.empty();
    bool has_text = !oaicompat_msg.content.empty();
    size_t num_tool_calls = oaicompat_msg.tool_calls.size();

    bool text_block_started = false;
    // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+)
    size_t thinking_block_index = 0;
    size_t text_block_index = has_thinking ? 1 : 0;

    bool thinking_block_started = false;
    bool text_block_started = false;
    std::unordered_set<size_t> tool_calls_started;

    for (const auto & diff : oaicompat_msg_diffs) {
        // handle thinking/reasoning content
        if (!diff.reasoning_content_delta.empty()) {
            if (!thinking_block_started) {
                events.push_back({
                    {"event", "content_block_start"},
                    {"data", {
                        {"type", "content_block_start"},
                        {"index", thinking_block_index},
                        {"content_block", {
                            {"type", "thinking"},
                            {"thinking", ""}
                        }}
                    }}
                });
                thinking_block_started = true;
            }

            events.push_back({
                {"event", "content_block_delta"},
                {"data", {
                    {"type", "content_block_delta"},
                    {"index", thinking_block_index},
                    {"delta", {
                        {"type", "thinking_delta"},
                        {"thinking", diff.reasoning_content_delta}
                    }}
                }}
            });
        }

        // handle regular text content
        if (!diff.content_delta.empty()) {
            if (!text_block_started) {
                events.push_back({
                    {"event", "content_block_start"},
                    {"data", {
                        {"type", "content_block_start"},
                        {"index", 0},
                        {"index", text_block_index},
                        {"content_block", {
                            {"type", "text"},
                            {"text", ""}
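The index arithmetic above fixes the Anthropic content-block layout as thinking -> text -> tool_use. A minimal standalone illustration of that scheme (not code from the patch):

#include <cstddef>
#include <cstdio>

int main() {
    const bool has_thinking = true;
    const bool has_text     = true;

    const size_t thinking_block_index = 0;
    const size_t text_block_index     = has_thinking ? 1 : 0;

    // tool call i is emitted after the optional thinking and text blocks
    auto tool_block_index = [&](size_t i) {
        return (size_t) (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + i;
    };

    std::printf("thinking=%zu text=%zu tool0=%zu\n",
                thinking_block_index, text_block_index, tool_block_index(0)); // prints 0 1 2
    return 0;
}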
@@ -889,7 +935,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
                {"event", "content_block_delta"},
                {"data", {
                    {"type", "content_block_delta"},
                    {"index", 0},
                    {"index", text_block_index},
                    {"delta", {
                        {"type", "text_delta"},
                        {"text", diff.content_delta}
@@ -898,8 +944,9 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
            });
        }

        // handle tool calls
        if (diff.tool_call_index != std::string::npos) {
            size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;
            size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + diff.tool_call_index;

            if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
                const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];
@@ -935,18 +982,42 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
            }
        }

    // close content blocks in order
    if (has_thinking) {
        // Anthropic API requires a signature_delta before closing thinking blocks
        // We use an empty signature since we can't generate a cryptographic signature for local models
        events.push_back({
            {"event", "content_block_delta"},
            {"data", {
                {"type", "content_block_delta"},
                {"index", thinking_block_index},
                {"delta", {
                    {"type", "signature_delta"},
                    {"signature", ""}
                }}
            }}
        });
        events.push_back({
            {"event", "content_block_stop"},
            {"data", {
                {"type", "content_block_stop"},
                {"index", thinking_block_index}
            }}
        });
    }

    if (has_text) {
        events.push_back({
            {"event", "content_block_stop"},
            {"data", {
                {"type", "content_block_stop"},
                {"index", 0}
                {"index", text_block_index}
            }}
        });
    }

    for (size_t i = 0; i < num_tool_calls; i++) {
        size_t content_block_index = (has_text ? 1 : 0) + i;
        size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + i;
        events.push_back({
            {"event", "content_block_stop"},
            {"data", {
@@ -1154,11 +1225,10 @@ json server_task_result_rerank::to_json() {
json server_task_result_cmpl_partial::to_json_anthropic() {
    json events = json::array();
    bool first = (n_decoded == 1);
    bool text_block_started = false;
    // use member variables to track block state across streaming calls
    // (anthropic_thinking_block_started, anthropic_text_block_started)

    if (first) {
        text_block_started = false;

        events.push_back({
            {"event", "message_start"},
            {"data", {
@@ -1180,28 +1250,69 @@ json server_task_result_cmpl_partial::to_json_anthropic() {
        });
    }

    // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+)
    size_t thinking_block_index = 0;
    // use anthropic_has_reasoning (set in update()) to know if ANY reasoning was generated
    size_t text_block_index = anthropic_has_reasoning ? 1 : 0;

    // use local copies of streaming state (copied from task_result_state in update())
    // these reflect the state BEFORE this chunk was processed
    bool thinking_started = anthropic_thinking_block_started;
    bool text_started = anthropic_text_block_started;

    for (const auto & diff : oaicompat_msg_diffs) {
        if (!diff.content_delta.empty()) {
            if (!text_block_started) {
        // handle thinking/reasoning content
        if (!diff.reasoning_content_delta.empty()) {
            if (!thinking_started) {
                events.push_back({
                    {"event", "content_block_start"},
                    {"data", {
                        {"type", "content_block_start"},
                        {"index", 0},
                        {"index", thinking_block_index},
                        {"content_block", {
                            {"type", "text"},
                            {"text", ""}
                            {"type", "thinking"},
                            {"thinking", ""}
                        }}
                    }}
                });
                text_block_started = true;
                thinking_started = true;
            }

            events.push_back({
                {"event", "content_block_delta"},
                {"data", {
                    {"type", "content_block_delta"},
                    {"index", 0},
                    {"index", thinking_block_index},
                    {"delta", {
                        {"type", "thinking_delta"},
                        {"thinking", diff.reasoning_content_delta}
                    }}
                }}
            });
        }

        // handle regular text content
        if (!diff.content_delta.empty()) {
            if (!text_started) {
                events.push_back({
                    {"event", "content_block_start"},
                    {"data", {
                        {"type", "content_block_start"},
                        {"index", text_block_index},
                        {"content_block", {
                            {"type", "text"},
                            {"text", ""}
                        }}
                    }}
                });
                text_started = true;
            }

            events.push_back({
                {"event", "content_block_delta"},
                {"data", {
                    {"type", "content_block_delta"},
                    {"index", text_block_index},
                    {"delta", {
                        {"type", "text_delta"},
                        {"text", diff.content_delta}
@@ -1210,8 +1321,10 @@ json server_task_result_cmpl_partial::to_json_anthropic() {
            });
        }

        // handle tool calls
        if (diff.tool_call_index != std::string::npos) {
            size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;
            // use anthropic_has_reasoning for thinking block count (persists across calls)
            size_t content_block_index = (anthropic_has_reasoning ? 1 : 0) + (text_started ? 1 : 0) + diff.tool_call_index;

            if (!diff.tool_call_delta.name.empty()) {
                events.push_back({
@@ -96,6 +96,10 @@ struct task_result_state {
    std::string generated_text; // append new chunks of generated text here
    std::vector<std::string> generated_tool_call_ids;

    // for Anthropic API streaming: track content block state across chunks
    bool anthropic_thinking_block_started = false;
    bool anthropic_text_block_started = false;

    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
@@ -337,6 +341,12 @@ struct server_task_result_cmpl_partial : server_task_result {
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
    bool is_updated = false;

    // for Anthropic API: track if any reasoning content has been generated
    bool anthropic_has_reasoning = false;
    // Streaming state copied from task_result_state for this chunk
    bool anthropic_thinking_block_started = false;
    bool anthropic_text_block_started = false;

    virtual bool is_stop() override {
        return false; // in stream mode, partial responses are not considered stop
    }
@@ -346,6 +356,22 @@ struct server_task_result_cmpl_partial : server_task_result {
    virtual void update(task_result_state & state) override {
        is_updated = true;
        state.update_chat_msg(content, true, oaicompat_msg_diffs);
        // track if the accumulated message has any reasoning content
        anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty();

        // Copy current state for use in to_json_anthropic() (reflects state BEFORE this chunk)
        anthropic_thinking_block_started = state.anthropic_thinking_block_started;
        anthropic_text_block_started = state.anthropic_text_block_started;

        // Pre-compute state updates based on diffs (for next chunk)
        for (const auto & diff : oaicompat_msg_diffs) {
            if (!diff.reasoning_content_delta.empty() && !state.anthropic_thinking_block_started) {
                state.anthropic_thinking_block_started = true;
            }
            if (!diff.content_delta.empty() && !state.anthropic_text_block_started) {
                state.anthropic_text_block_started = true;
            }
        }
    }

    json to_json_non_oaicompat();
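The snapshot-then-advance pattern in update() above is what lets each streamed chunk emit content_block_start exactly once per block. The self-contained toy model below (an assumption for illustration, not the server code) mimics that handshake: each chunk records the shared flags as they were before it, emits a block start only if its snapshot says the block was not open yet, then advances the flags for the next chunk.

#include <cstdio>
#include <string>
#include <vector>

struct shared_state {
    bool thinking_started = false;
    bool text_started     = false;
};

struct chunk {
    std::string reasoning_delta;
    std::string text_delta;
    bool thinking_started_before = false;
    bool text_started_before     = false;

    void update(shared_state & s) {
        // snapshot the flags as they were BEFORE this chunk ...
        thinking_started_before = s.thinking_started;
        text_started_before     = s.text_started;
        // ... then advance them for the next chunk
        if (!reasoning_delta.empty()) { s.thinking_started = true; }
        if (!text_delta.empty())      { s.text_started     = true; }
    }

    void emit() const {
        if (!reasoning_delta.empty() && !thinking_started_before) { std::puts("content_block_start (thinking)"); }
        if (!reasoning_delta.empty())                             { std::printf("thinking_delta: %s\n", reasoning_delta.c_str()); }
        if (!text_delta.empty() && !text_started_before)          { std::puts("content_block_start (text)"); }
        if (!text_delta.empty())                                  { std::printf("text_delta: %s\n", text_delta.c_str()); }
    }
};

int main() {
    shared_state st;
    std::vector<chunk> chunks = { {"2+2 is", ""}, {" four", ""}, {"", "4"} };
    for (auto & c : chunks) {
        c.update(st);
        c.emit();   // "content_block_start" appears exactly once per block type
    }
    return 0;
}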
@@ -805,3 +805,92 @@ def test_anthropic_vs_openai_different_response_format():
    assert "input_tokens" in anthropic_res.body["usage"]
    assert "completion_tokens" in openai_res.body["usage"]
    assert "output_tokens" in anthropic_res.body["usage"]


# Extended thinking tests with reasoning models

@pytest.mark.slow
@pytest.mark.parametrize("stream", [False, True])
def test_anthropic_thinking_with_reasoning_model(stream):
    """Test that thinking content blocks are properly returned for reasoning models"""
    global server
    server = ServerProcess()
    server.model_hf_repo = "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF"
    server.model_hf_file = "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf"
    server.reasoning_format = "deepseek"
    server.jinja = True
    server.n_ctx = 8192
    server.n_predict = 1024
    server.server_port = 8084
    server.start(timeout_seconds=600)  # large model needs time to download

    if stream:
        res = server.make_stream_request("POST", "/v1/messages", data={
            "model": "test",
            "max_tokens": 1024,
            "thinking": {
                "type": "enabled",
                "budget_tokens": 500
            },
            "messages": [
                {"role": "user", "content": "What is 2+2?"}
            ],
            "stream": True
        })

        events = list(res)

        # should have thinking content block events
        thinking_starts = [e for e in events if
                           e.get("type") == "content_block_start" and
                           e.get("content_block", {}).get("type") == "thinking"]
        assert len(thinking_starts) > 0, "Should have thinking content_block_start event"
        assert thinking_starts[0]["index"] == 0, "Thinking block should be at index 0"

        # should have thinking_delta events
        thinking_deltas = [e for e in events if
                           e.get("type") == "content_block_delta" and
                           e.get("delta", {}).get("type") == "thinking_delta"]
        assert len(thinking_deltas) > 0, "Should have thinking_delta events"

        # should have signature_delta event before thinking block closes (Anthropic API requirement)
        signature_deltas = [e for e in events if
                            e.get("type") == "content_block_delta" and
                            e.get("delta", {}).get("type") == "signature_delta"]
        assert len(signature_deltas) > 0, "Should have signature_delta event for thinking block"

        # should have text block after thinking
        text_starts = [e for e in events if
                       e.get("type") == "content_block_start" and
                       e.get("content_block", {}).get("type") == "text"]
        assert len(text_starts) > 0, "Should have text content_block_start event"
        assert text_starts[0]["index"] == 1, "Text block should be at index 1 (after thinking)"
    else:
        res = server.make_request("POST", "/v1/messages", data={
            "model": "test",
            "max_tokens": 1024,
            "thinking": {
                "type": "enabled",
                "budget_tokens": 500
            },
            "messages": [
                {"role": "user", "content": "What is 2+2?"}
            ]
        })

        assert res.status_code == 200
        assert res.body["type"] == "message"

        content = res.body["content"]
        assert len(content) >= 2, "Should have at least thinking and text blocks"

        # first block should be thinking
        thinking_blocks = [b for b in content if b.get("type") == "thinking"]
        assert len(thinking_blocks) > 0, "Should have thinking content block"
        assert "thinking" in thinking_blocks[0], "Thinking block should have 'thinking' field"
        assert len(thinking_blocks[0]["thinking"]) > 0, "Thinking content should not be empty"
        assert "signature" in thinking_blocks[0], "Thinking block should have 'signature' field (Anthropic API requirement)"

        # should also have text block
        text_blocks = [b for b in content if b.get("type") == "text"]
        assert len(text_blocks) > 0, "Should have text content block"