From e30dc63cb608b9e33cafdb61b1d2ec19c6ff1297 Mon Sep 17 00:00:00 2001
From: "shaobo.xie"
Date: Thu, 12 Feb 2026 11:40:56 +0800
Subject: [PATCH 1/4] context : add compute-type to set intermediate computation precision
---
 common/arg.cpp        | 17 +++++++++++++++++
 common/common.cpp     |  1 +
 common/common.h       |  1 +
 include/llama.h       | 10 ++++++++++
 src/llama-context.cpp | 20 ++++++++++++++++++++
 src/llama-cparams.h   |  2 ++
 src/llama-graph.cpp   | 18 ++++++++++++++++++
 src/llama-graph.h     |  8 ++++++++
 src/llama.cpp         | 14 ++++++++++++++
 9 files changed, 91 insertions(+)

diff --git a/common/arg.cpp b/common/arg.cpp
index 9c85696ebd..28a1f68a80 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1332,6 +1332,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                     string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
             }
         }).set_env("LLAMA_ARG_FLASH_ATTN"));
+    add_opt(common_arg({ "-ct", "--compute-type" }, "[f32|f16|bf16|default]",
+        string_format("set intermediate computation precision ('f32', 'f16', 'bf16' or 'default', default: '%s')",
+            llama_compute_type_name(params.compute_type)),
+        [](common_params & params, const std::string & value) {
+            if (value == "f32") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_F32;
+            } else if (value == "f16") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_F16;
+            } else if (value == "bf16") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_BF16;
+            } else if (value == "default") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_DEFAULT;
+            } else {
+                throw std::runtime_error(
+                    string_format("error: unknown value for --compute-type: '%s'\n", value.c_str()));
+            }
+        }).set_env("LLAMA_ARG_COMPUTE_TYPE"));
     add_opt(common_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with; for system message, use -sys",
diff --git a/common/common.cpp b/common/common.cpp
index 93474f88c7..6baec1705d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1412,6 +1412,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.pooling_type      = params.pooling_type;
     cparams.attention_type    = params.attention_type;
     cparams.flash_attn_type   = params.flash_attn_type;
+    cparams.compute_type      = params.compute_type;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
diff --git a/common/common.h b/common/common.h
index 804485fb19..1d8fb06d5b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -402,6 +402,7 @@ struct common_params {
     enum llama_pooling_type    pooling_type    = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type  attention_type  = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
     enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
+    enum llama_compute_type    compute_type    = LLAMA_COMPUTE_TYPE_DEFAULT; // intermediate computation precision

     struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
diff --git a/include/llama.h b/include/llama.h
index 46c3672e98..542c07e163 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -188,6 +188,15 @@ extern "C" {

     LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);

+    enum llama_compute_type {
+        LLAMA_COMPUTE_TYPE_DEFAULT = 0, // no override, use model's native precision
+        LLAMA_COMPUTE_TYPE_F32     = 1,
+        LLAMA_COMPUTE_TYPE_F16     = 2,
+        LLAMA_COMPUTE_TYPE_BF16    = 3,
+    };
+
+    LLAMA_API const char * llama_compute_type_name(enum llama_compute_type compute_type);
+
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -336,6 +345,7 @@ extern "C" {
         enum llama_pooling_type    pooling_type;    // whether to pool (sum) embedding results by sequence id
         enum llama_attention_type  attention_type;  // attention type to use for embeddings
         enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention
+        enum llama_compute_type    compute_type;    // intermediate activation precision [EXPERIMENTAL]

         // ref: https://github.com/ggml-org/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency, 0 = from model
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 6b43ca1926..af021984cc 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -148,6 +148,24 @@ llama_context::llama_context(
     cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
     cparams.auto_fa    = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;

+    // convert public enum → internal ggml_type for intermediate computation precision
+    switch (params.compute_type) {
+        case LLAMA_COMPUTE_TYPE_F32:
+            cparams.compute_type = GGML_TYPE_F32;
+            break;
+        case LLAMA_COMPUTE_TYPE_F16:
+            cparams.compute_type = GGML_TYPE_F16;
+            break;
+        case LLAMA_COMPUTE_TYPE_BF16:
+            cparams.compute_type = GGML_TYPE_BF16;
+            break;
+        case LLAMA_COMPUTE_TYPE_DEFAULT:
+        default:
+            // DEFAULT = no override, use model's native precision (F32 for now)
+            cparams.compute_type = GGML_TYPE_F32;
+            break;
+    }
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

@@ -194,6 +212,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
+    LLAMA_LOG_INFO("%s: compute_type = %s\n", __func__, llama_compute_type_name(params.compute_type));
     LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -2951,6 +2970,7 @@ llama_context_params llama_context_default_params() {
        /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
        /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
        /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
+       /*.compute_type =*/ LLAMA_COMPUTE_TYPE_DEFAULT,
        /*.rope_freq_base =*/ 0.0f,
        /*.rope_freq_scale =*/ 0.0f,
        /*.yarn_ext_factor =*/ -1.0f,
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 2da3bbd6f9..8a37b36e75 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -39,6 +39,8 @@ struct llama_cparams {

     enum llama_pooling_type pooling_type;

+    ggml_type compute_type;
+
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
 };
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index bba747d37b..1bee470264 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -862,6 +862,24 @@ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     }
 }

+ggml_tensor * llm_graph_context::build_cast_to_compute_type(
+        ggml_context * ctx,
+        ggml_tensor * cur) const {
+    if (cparams.compute_type == GGML_TYPE_F32 || cur->type == cparams.compute_type) {
+        return cur;
+    }
+    return ggml_cast(ctx, cur, cparams.compute_type);
+}
+
+ggml_tensor * llm_graph_context::build_cast_to_f32(
+        ggml_context * ctx,
+        ggml_tensor * cur) const {
+    if (cur->type == GGML_TYPE_F32) {
+        return cur;
+    }
+    return ggml_cast(ctx, cur, GGML_TYPE_F32);
+}
+
 ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
         int il) const {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 1d69ff1a6f..ce440ef76b 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -756,6 +756,14 @@ struct llm_graph_context {

     void cb(ggml_tensor * cur, const char * name, int il) const;

+    ggml_tensor * build_cast_to_compute_type( // intermediate computation precision.
+            ggml_context * ctx,
+            ggml_tensor * cur) const;
+
+    ggml_tensor * build_cast_to_f32(
+            ggml_context * ctx,
+            ggml_tensor * cur) const;
+
     //
     // common
     //
diff --git a/src/llama.cpp b/src/llama.cpp
index 6da90d6f1f..51cfe18c60 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -43,6 +43,20 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_ty
     GGML_ABORT("fatal error");
 }

+const char * llama_compute_type_name(enum llama_compute_type compute_type) {
+    switch (compute_type) {
+        case LLAMA_COMPUTE_TYPE_DEFAULT:
+            return "default";
+        case LLAMA_COMPUTE_TYPE_F32:
+            return "f32";
+        case LLAMA_COMPUTE_TYPE_F16:
+            return "f16";
+        case LLAMA_COMPUTE_TYPE_BF16:
+            return "bf16";
+    }
+    GGML_ABORT("fatal error");
+}
+
 struct llama_device_memory_data {
     int64_t total;
     int64_t free;

From f47e50a18bff69bfe2b9445f957e649719aea77e Mon Sep 17 00:00:00 2001
From: "shaobo.xie"
Date: Thu, 12 Feb 2026 12:19:44 +0800
Subject: [PATCH 2/4] context : add model_supports_compute_type check; support is still model-specific for now
---
 src/llama-context.cpp | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index af021984cc..ffa219f167 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -14,6 +14,32 @@
 #include
 #include

+static bool model_supports_compute_type(enum llm_arch arch, ggml_type compute_type) {
+    // F32 is always supported - it's the default/safe precision
+    if (compute_type == GGML_TYPE_F32) {
+        return true;
+    }
+
+    // For now, FP16 and BF16 support is model-specific.
+    // Add models here once their required ops are implemented and validated for the requested compute_type.
+    // Example (uncomment when ready):
+    // if (compute_type == GGML_TYPE_F16 || compute_type == GGML_TYPE_BF16) {
+    //     switch (arch) {
+    //         case LLM_ARCH_QWEN2:
+    //         case LLM_ARCH_QWEN2MOE:
+    //         case LLM_ARCH_QWEN3:
+    //         // ... other validated models
+    //             return true;
+    //         default:
+    //             return false;
+    //     }
+    // }
+
+    // No models enabled yet for non-F32 compute types
+    (void)arch;
+    return false;
+}
+
 //
 // llama_context
 //
@@ -166,6 +192,17 @@ llama_context::llama_context(
             break;
     }

+    // check if the model supports the requested compute type
+    if (cparams.compute_type != GGML_TYPE_F32) {
+        if (!model_supports_compute_type(model.arch, cparams.compute_type)) {
+            LLAMA_LOG_WARN("%s: model arch '%s' does not yet support compute_type %s, "
+                    "falling back to F32. To enable, the required ops must be implemented first.\n",
+                    __func__, llm_arch_name(model.arch),
+                    ggml_type_name(cparams.compute_type));
+            cparams.compute_type = GGML_TYPE_F32;
+        }
+    }
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

From d38923a1ecd64c919767cb867f4a898d621daeec Mon Sep 17 00:00:00 2001
From: "shaobo.xie"
Date: Thu, 12 Feb 2026 12:50:55 +0800
Subject: [PATCH 3/4] context : add set_result_logits to cast from compute_type back to F32 for the final logits output
---
 src/llama-graph.cpp | 9 +++++++++
 src/llama-graph.h   | 5 ++++-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1bee470264..488b055aa9 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -880,6 +880,12 @@ ggml_tensor * llm_graph_context::build_cast_to_f32(
     return ggml_cast(ctx, cur, GGML_TYPE_F32);
 }

+ggml_tensor * llm_graph_context::set_result_logits(ggml_tensor * cur) {
+    cur = build_cast_to_f32(ctx0, cur);
+    res->t_logits = cur;
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
         int il) const {
@@ -1544,6 +1550,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     // ref: https://github.com/ggml-org/llama.cpp/pull/18599
     ggml_build_forward_expand(gf, cur);

+    // cast to compute_type if needed (e.g., F16 for intermediate activations)
+    cur = build_cast_to_compute_type(ctx0, cur);
+
     return cur;
 }

diff --git a/src/llama-graph.h b/src/llama-graph.h
index ce440ef76b..920aef80cc 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -756,7 +756,8 @@ struct llm_graph_context {

     void cb(ggml_tensor * cur, const char * name, int il) const;

-    ggml_tensor * build_cast_to_compute_type( // intermediate computation precision.
+    // intermediate computation precision.
+    ggml_tensor * build_cast_to_compute_type(
             ggml_context * ctx,
             ggml_tensor * cur) const;
@@ -764,6 +765,8 @@
             ggml_context * ctx,
             ggml_tensor * cur) const;

+    ggml_tensor * set_result_logits(ggml_tensor * cur);
+
     //
     // common
     //

From 233f5ab82d447e08e4bbdfc504a2fcbb0670617f Mon Sep 17 00:00:00 2001
From: "shaobo.xie"
Date: Thu, 12 Feb 2026 14:06:38 +0800
Subject: [PATCH 4/4] context : refactor model_supports_compute_type into a lambda to keep the // llama_context // section banner clean
---
 src/llama-context.cpp | 60 +++++++++++++++++++------------------------
 1 file changed, 26 insertions(+), 34 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ffa219f167..b49afc94b0 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -14,32 +14,6 @@
 #include
 #include

-static bool model_supports_compute_type(enum llm_arch arch, ggml_type compute_type) {
-    // F32 is always supported - it's the default/safe precision
-    if (compute_type == GGML_TYPE_F32) {
-        return true;
-    }
-
-    // For now, FP16 and BF16 support is model-specific.
-    // Add models here once their required ops are implemented and validated for the requested compute_type.
-    // Example (uncomment when ready):
-    // if (compute_type == GGML_TYPE_F16 || compute_type == GGML_TYPE_BF16) {
-    //     switch (arch) {
-    //         case LLM_ARCH_QWEN2:
-    //         case LLM_ARCH_QWEN2MOE:
-    //         case LLM_ARCH_QWEN3:
-    //         // ... other validated models
-    //             return true;
-    //         default:
-    //             return false;
-    //     }
-    // }
-
-    // No models enabled yet for non-F32 compute types
-    (void)arch;
-    return false;
-}
-
 //
 // llama_context
 //
@@ -192,15 +166,33 @@ llama_context::llama_context(
             break;
     }

-    // check if the model supports the requested compute type
-    if (cparams.compute_type != GGML_TYPE_F32) {
-        if (!model_supports_compute_type(model.arch, cparams.compute_type)) {
-            LLAMA_LOG_WARN("%s: model arch '%s' does not yet support compute_type %s, "
-                    "falling back to F32. To enable, the required ops must be implemented first.\n",
-                    __func__, llm_arch_name(model.arch),
-                    ggml_type_name(cparams.compute_type));
-            cparams.compute_type = GGML_TYPE_F32;
-        }
-    }
+    // For now, FP16 and BF16 support is model-specific.
+    // Add models here once their required ops are implemented and validated for the requested compute_type.
+    auto model_supports_compute_type = [&](ggml_type ct) -> bool {
+        if (ct == GGML_TYPE_F32) {
+            return true; // F32 is always supported
+        }
+        // Example (uncomment when ready):
+        // if (ct == GGML_TYPE_F16 || ct == GGML_TYPE_BF16) {
+        //     switch (model.arch) {
+        //         case LLM_ARCH_QWEN2:
+        //         case LLM_ARCH_QWEN2MOE:
+        //         case LLM_ARCH_QWEN3:
+        //             return true;
+        //         default:
+        //             return false;
+        //     }
+        // }
+        (void)model.arch; // no models enabled yet for non-F32 compute types
+        return false;
+    };
+
+    if (!model_supports_compute_type(cparams.compute_type)) {
+        LLAMA_LOG_WARN("%s: model arch '%s' does not yet support compute_type %s, "
+                "falling back to F32. To enable, the required ops must be implemented first.\n",
+                __func__, llm_arch_name(model.arch),
+                ggml_type_name(cparams.compute_type));
+        cparams.compute_type = GGML_TYPE_F32;
+    }

     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
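
Usage sketch (editorial note, not part of the patch series): with the four patches above applied, the precision override can be requested either through the CLI flag from PATCH 1 (for example "llama-cli -m model.gguf -ct f16", or the LLAMA_ARG_COMPUTE_TYPE environment variable) or directly through the new llama_context_params field. The snippet below is hypothetical consumer code; the model path and the choice of f16 are placeholders, and the API names around model/context creation are the ones in current llama.h, so they may need adjusting to the tree this series targets. Until an architecture is whitelisted in model_supports_compute_type, the context warns and falls back to F32 as implemented in PATCH 2/4.

    // hypothetical consumer code exercising the new compute_type field
    #include "llama.h"

    int main() {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

        llama_context_params cparams = llama_context_default_params();
        cparams.compute_type = LLAMA_COMPUTE_TYPE_F16; // request F16 intermediate activations

        // the constructor logs the requested compute_type and, for unsupported archs,
        // warns and falls back to F32 (PATCH 2/4)
        llama_context * ctx = llama_init_from_model(model, cparams);

        // ... tokenize, llama_decode(), read logits (always F32, see set_result_logits in PATCH 3) ...

        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }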