From e30dc63cb608b9e33cafdb61b1d2ec19c6ff1297 Mon Sep 17 00:00:00 2001
From: "shaobo.xie"
Date: Thu, 12 Feb 2026 11:40:56 +0800
Subject: [PATCH] context : add compute-type to set intermediate computation
 precision

---
 common/arg.cpp        | 17 +++++++++++++++++
 common/common.cpp     |  1 +
 common/common.h       |  1 +
 include/llama.h       | 10 ++++++++++
 src/llama-context.cpp | 20 ++++++++++++++++++++
 src/llama-cparams.h   |  2 ++
 src/llama-graph.cpp   | 18 ++++++++++++++++++
 src/llama-graph.h     |  8 ++++++++
 src/llama.cpp         | 14 ++++++++++++++
 9 files changed, 91 insertions(+)

diff --git a/common/arg.cpp b/common/arg.cpp
index 9c85696ebd..28a1f68a80 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1332,6 +1332,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                     string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
             }
         }).set_env("LLAMA_ARG_FLASH_ATTN"));
+    add_opt(common_arg({ "-ct", "--compute-type" }, "[f32|f16|bf16|default]",
+        string_format("set intermediate computation precision ('f32', 'f16', 'bf16' or 'default', default: '%s')",
+                      llama_compute_type_name(params.compute_type)),
+        [](common_params & params, const std::string & value) {
+            if (value == "f32") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_F32;
+            } else if (value == "f16") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_F16;
+            } else if (value == "bf16") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_BF16;
+            } else if (value == "default") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_DEFAULT;
+            } else {
+                throw std::runtime_error(
+                    string_format("error: unknown value for --compute-type: '%s'\n", value.c_str()));
+            }
+        }).set_env("LLAMA_ARG_COMPUTE_TYPE"));
     add_opt(common_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with; for system message, use -sys",
diff --git a/common/common.cpp b/common/common.cpp
index 93474f88c7..6baec1705d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1412,6 +1412,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.pooling_type      = params.pooling_type;
     cparams.attention_type    = params.attention_type;
     cparams.flash_attn_type   = params.flash_attn_type;
+    cparams.compute_type      = params.compute_type;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
diff --git a/common/common.h b/common/common.h
index 804485fb19..1d8fb06d5b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -402,6 +402,7 @@ struct common_params {
     enum llama_pooling_type    pooling_type    = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type  attention_type  = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
     enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
+    enum llama_compute_type    compute_type    = LLAMA_COMPUTE_TYPE_DEFAULT; // intermediate computation precision

     struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
diff --git a/include/llama.h b/include/llama.h
index 46c3672e98..542c07e163 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -188,6 +188,15 @@ extern "C" {

     LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);

+    enum llama_compute_type {
+        LLAMA_COMPUTE_TYPE_DEFAULT = 0, // no override, use model's native precision
+        LLAMA_COMPUTE_TYPE_F32     = 1,
+        LLAMA_COMPUTE_TYPE_F16     = 2,
+        LLAMA_COMPUTE_TYPE_BF16    = 3,
+    };
+
+    LLAMA_API const char * llama_compute_type_name(enum llama_compute_type compute_type);
+
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -336,6 +345,7 @@ extern "C" {
         enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
         enum llama_attention_type attention_type; // attention type to use for embeddings
         enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention
+        enum llama_compute_type compute_type; // intermediate activation precision [EXPERIMENTAL]

         // ref: https://github.com/ggml-org/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency, 0 = from model
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 6b43ca1926..af021984cc 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -148,6 +148,24 @@ llama_context::llama_context(
     cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
     cparams.auto_fa    = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;

+    // convert public enum → internal ggml_type for intermediate computation precision
+    switch (params.compute_type) {
+        case LLAMA_COMPUTE_TYPE_F32:
+            cparams.compute_type = GGML_TYPE_F32;
+            break;
+        case LLAMA_COMPUTE_TYPE_F16:
+            cparams.compute_type = GGML_TYPE_F16;
+            break;
+        case LLAMA_COMPUTE_TYPE_BF16:
+            cparams.compute_type = GGML_TYPE_BF16;
+            break;
+        case LLAMA_COMPUTE_TYPE_DEFAULT:
+        default:
+            // DEFAULT = no override, use model's native precision (F32 for now)
+            cparams.compute_type = GGML_TYPE_F32;
+            break;
+    }
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

@@ -194,6 +212,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch     = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn  = %d\n", __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn   = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
+    LLAMA_LOG_INFO("%s: compute_type = %s\n", __func__, llama_compute_type_name(params.compute_type));
     LLAMA_LOG_INFO("%s: kv_unified   = %s\n", __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base    = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale   = %g\n", __func__, cparams.rope_freq_scale);
@@ -2951,6 +2970,7 @@ llama_context_params llama_context_default_params() {
         /*.pooling_type    =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.attention_type  =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
+        /*.compute_type    =*/ LLAMA_COMPUTE_TYPE_DEFAULT,
         /*.rope_freq_base  =*/ 0.0f,
         /*.rope_freq_scale =*/ 0.0f,
         /*.yarn_ext_factor =*/ -1.0f,
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 2da3bbd6f9..8a37b36e75 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -39,6 +39,8 @@ struct llama_cparams {

     enum llama_pooling_type pooling_type;

+    ggml_type compute_type;
+
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
 };
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index bba747d37b..1bee470264 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -862,6 +862,24 @@ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     }
 }

+ggml_tensor * llm_graph_context::build_cast_to_compute_type(
+        ggml_context * ctx,
+        ggml_tensor  * cur) const {
+    if (cparams.compute_type == GGML_TYPE_F32 || cur->type == cparams.compute_type) {
+        return cur;
+    }
+    return ggml_cast(ctx, cur, cparams.compute_type);
+}
+
+ggml_tensor * llm_graph_context::build_cast_to_f32(
+        ggml_context * ctx,
+        ggml_tensor  * cur) const {
+    if (cur->type == GGML_TYPE_F32) {
+        return cur;
+    }
+    return ggml_cast(ctx, cur, GGML_TYPE_F32);
+}
+
 ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
         int il) const {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 1d69ff1a6f..ce440ef76b 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -756,6 +756,14 @@ struct llm_graph_context {

     void cb(ggml_tensor * cur, const char * name, int il) const;

+    ggml_tensor * build_cast_to_compute_type( // intermediate computation precision
+            ggml_context * ctx,
+            ggml_tensor  * cur) const;
+
+    ggml_tensor * build_cast_to_f32(
+            ggml_context * ctx,
+            ggml_tensor  * cur) const;
+
     //
     // common
     //
diff --git a/src/llama.cpp b/src/llama.cpp
index 6da90d6f1f..51cfe18c60 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -43,6 +43,20 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_ty
     GGML_ABORT("fatal error");
 }

+const char * llama_compute_type_name(enum llama_compute_type compute_type) {
+    switch (compute_type) {
+        case LLAMA_COMPUTE_TYPE_DEFAULT:
+            return "default";
+        case LLAMA_COMPUTE_TYPE_F32:
+            return "f32";
+        case LLAMA_COMPUTE_TYPE_F16:
+            return "f16";
+        case LLAMA_COMPUTE_TYPE_BF16:
+            return "bf16";
+    }
+    GGML_ABORT("fatal error");
+}
+
 struct llama_device_memory_data {
     int64_t total;
     int64_t free;