commit a2070b3a9d
Author: Blime
Date:   2026-02-16 16:45:58 -06:00 (committed by GitHub)
9 changed files with 132 additions and 0 deletions


@@ -1332,6 +1332,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
}
}).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg({ "-ct", "--compute-type" }, "[f32|f16|bf16|default]",
string_format("set intermediate computation precision ('f32', 'f16', 'bf16' or 'default', default: '%s')",
llama_compute_type_name(params.compute_type)),
[](common_params & params, const std::string & value) {
if (value == "f32") {
params.compute_type = LLAMA_COMPUTE_TYPE_F32;
} else if (value == "f16") {
params.compute_type = LLAMA_COMPUTE_TYPE_F16;
} else if (value == "bf16") {
params.compute_type = LLAMA_COMPUTE_TYPE_BF16;
} else if (value == "default") {
params.compute_type = LLAMA_COMPUTE_TYPE_DEFAULT;
} else {
throw std::runtime_error(
string_format("error: unknown value for --compute-type: '%s'\n", value.c_str()));
}
}).set_env("LLAMA_ARG_COMPUTE_TYPE"));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",

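The flag is mirrored by the LLAMA_ARG_COMPUTE_TYPE environment variable and simply selects one of the enum values stored in common_params. A minimal sketch of the equivalent programmatic assignment, using only the field and enum value introduced in this commit:

common_params params;
params.compute_type = LLAMA_COMPUTE_TYPE_F16; // same effect as `--compute-type f16` or LLAMA_ARG_COMPUTE_TYPE=f16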

@@ -1397,6 +1397,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.pooling_type = params.pooling_type;
cparams.attention_type = params.attention_type;
cparams.flash_attn_type = params.flash_attn_type;
cparams.compute_type = params.compute_type;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;


@@ -402,6 +402,7 @@ struct common_params {
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
enum llama_compute_type compute_type = LLAMA_COMPUTE_TYPE_DEFAULT; // intermediate computation precision
struct common_params_sampling sampling;
struct common_params_speculative speculative;


@@ -188,6 +188,15 @@ extern "C" {
LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
enum llama_compute_type {
LLAMA_COMPUTE_TYPE_DEFAULT = 0, // no override, use model's native precision.
LLAMA_COMPUTE_TYPE_F32 = 1,
LLAMA_COMPUTE_TYPE_F16 = 2,
LLAMA_COMPUTE_TYPE_BF16 = 3,
};
LLAMA_API const char * llama_compute_type_name(enum llama_compute_type compute_type);
enum llama_split_mode {
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -336,6 +345,7 @@ extern "C" {
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
enum llama_attention_type attention_type; // attention type to use for embeddings
enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention
enum llama_compute_type compute_type; // intermediate activation precision [EXPERIMENTAL]
// ref: https://github.com/ggml-org/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model

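A minimal usage sketch of the new context parameter, assuming `model` is a valid llama_model * obtained through the existing loading API; everything except compute_type is the pre-existing public interface:

// sketch: request BF16 intermediate activations for a new context
llama_context_params cparams = llama_context_default_params();
cparams.compute_type = LLAMA_COMPUTE_TYPE_BF16;               // field added by this commit
llama_context * ctx = llama_init_from_model(model, cparams);
// if the model's architecture is not whitelisted yet, the context falls back
// to F32 and logs a warning (see the llama-context.cpp hunk below)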

@@ -150,6 +150,53 @@ llama_context::llama_context(
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
// convert public enum → internal ggml_type for intermediate computation precision
switch (params.compute_type) {
case LLAMA_COMPUTE_TYPE_F32:
cparams.compute_type = GGML_TYPE_F32;
break;
case LLAMA_COMPUTE_TYPE_F16:
cparams.compute_type = GGML_TYPE_F16;
break;
case LLAMA_COMPUTE_TYPE_BF16:
cparams.compute_type = GGML_TYPE_BF16;
break;
case LLAMA_COMPUTE_TYPE_DEFAULT:
default:
// DEFAULT = no override, use model's native precision (F32 for now)
cparams.compute_type = GGML_TYPE_F32;
break;
}
// For now, FP16 and BF16 support is model-specific.
// Add models here once the ops they require implement 'compute_type' and have been validated.
auto model_supports_compute_type = [&](ggml_type ct) -> bool {
if (ct == GGML_TYPE_F32) {
return true; // F32 is always supported
}
// Example (uncomment when ready):
// if (ct == GGML_TYPE_F16 || ct == GGML_TYPE_BF16) {
// switch (model.arch) {
// case LLM_ARCH_QWEN2:
// case LLM_ARCH_QWEN2MOE:
// case LLM_ARCH_QWEN3:
// return true;
// default:
// return false;
// }
// }
(void)model.arch; // no models enabled yet for non-F32 compute types
return false;
};
if (!model_supports_compute_type(cparams.compute_type)) {
LLAMA_LOG_WARN("%s: model arch '%s' does not yet support compute_type %s, "
"falling back to F32. To enable, the required ops must be implemented first.\n",
__func__, llm_arch_name(model.arch),
ggml_type_name(cparams.compute_type));
cparams.compute_type = GGML_TYPE_F32;
}
// with causal attention, the batch size is limited by the context size
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -196,6 +243,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
LLAMA_LOG_INFO("%s: compute_type = %s\n", __func__, llama_compute_type_name(params.compute_type));
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -2942,6 +2990,7 @@ llama_context_params llama_context_default_params() {
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
/*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
/*.compute_type =*/ LLAMA_COMPUTE_TYPE_DEFAULT,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ -1.0f,


@@ -39,6 +39,8 @@ struct llama_cparams {
enum llama_pooling_type pooling_type;
ggml_type compute_type;
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
};


@@ -888,6 +888,30 @@ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
}
}
ggml_tensor * llm_graph_context::build_cast_to_compute_type(
ggml_context * ctx,
ggml_tensor * cur) const {
if (cparams.compute_type == GGML_TYPE_F32 || cur->type == cparams.compute_type) {
return cur;
}
return ggml_cast(ctx, cur, cparams.compute_type);
}
ggml_tensor * llm_graph_context::build_cast_to_f32(
ggml_context * ctx,
ggml_tensor * cur) const {
if (cur->type == GGML_TYPE_F32) {
return cur;
}
return ggml_cast(ctx, cur, GGML_TYPE_F32);
}
ggml_tensor * llm_graph_context::set_result_logits(ggml_tensor * cur) {
cur = build_cast_to_f32(ctx0, cur);
res->t_logits = cur;
return cur;
}
ggml_tensor * llm_graph_context::build_cvec(
ggml_tensor * cur,
int il) const {
@@ -1552,6 +1576,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
// ref: https://github.com/ggml-org/llama.cpp/pull/18599
ggml_build_forward_expand(gf, cur);
// cast to compute_type if needed (e.g., F16 for intermediate activations)
cur = build_cast_to_compute_type(ctx0, cur);
return cur;
}

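As a hypothetical illustration of how these helpers are intended to be used (no model graph is wired up this way in this commit), a build function could keep the hidden state in the reduced compute type and convert back to F32 only for the final logits:

// hypothetical sketch inside an llm_graph_context build function
ggml_tensor * cur = build_inp_embd(model.tok_embd);  // already cast to compute_type (see above)
// ... attention / FFN blocks keep operating on `cur` in compute_type ...
cur = build_cast_to_compute_type(ctx0, cur);          // re-cast after any op that produced F32
cur = ggml_mul_mat(ctx0, model.output, cur);          // output projection
cur = set_result_logits(cur);                         // casts to F32 and stores res->t_logits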

@@ -756,6 +756,17 @@ struct llm_graph_context {
void cb(ggml_tensor * cur, const char * name, int il) const;
// intermediate computation precision.
ggml_tensor * build_cast_to_compute_type(
ggml_context * ctx,
ggml_tensor * cur) const;
ggml_tensor * build_cast_to_f32(
ggml_context * ctx,
ggml_tensor * cur) const;
ggml_tensor * set_result_logits(ggml_tensor * cur);
//
// common
//


@@ -43,6 +43,20 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_ty
GGML_ABORT("fatal error");
}
const char * llama_compute_type_name(enum llama_compute_type compute_type) {
switch (compute_type) {
case LLAMA_COMPUTE_TYPE_DEFAULT:
return "default";
case LLAMA_COMPUTE_TYPE_F32:
return "f32";
case LLAMA_COMPUTE_TYPE_F16:
return "f16";
case LLAMA_COMPUTE_TYPE_BF16:
return "bf16";
}
GGML_ABORT("fatal error");
}
struct llama_device_memory_data {
int64_t total;
int64_t free;