diff --git a/common/arg.cpp b/common/arg.cpp index 18f953a38e..8eca51641a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1332,6 +1332,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str())); } }).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg({ "-ct", "--compute-type" }, "[f32|f16|bf16|default]", + string_format("set intermediate computation precision ('f32', 'f16', 'bf16' or 'default', default: '%s')", + llama_compute_type_name(params.compute_type)), + [](common_params & params, const std::string & value) { + if (value == "f32") { + params.compute_type = LLAMA_COMPUTE_TYPE_F32; + } else if (value == "f16") { + params.compute_type = LLAMA_COMPUTE_TYPE_F16; + } else if (value == "bf16") { + params.compute_type = LLAMA_COMPUTE_TYPE_BF16; + } else if (value == "default") { + params.compute_type = LLAMA_COMPUTE_TYPE_DEFAULT; + } else { + throw std::runtime_error( + string_format("error: unknown value for --compute-type: '%s'\n", value.c_str())); + } + }).set_env("LLAMA_ARG_COMPUTE_TYPE")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", diff --git a/common/common.cpp b/common/common.cpp index 32487ddc61..846c46d805 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1397,6 +1397,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.pooling_type = params.pooling_type; cparams.attention_type = params.attention_type; cparams.flash_attn_type = params.flash_attn_type; + cparams.compute_type = params.compute_type; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; diff --git a/common/common.h b/common/common.h index 6410248377..4f6d5fe37f 100644 --- a/common/common.h +++ b/common/common.h @@ -402,6 +402,7 @@ struct common_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention + enum llama_compute_type compute_type = LLAMA_COMPUTE_TYPE_DEFAULT; // intermediate computation precision struct common_params_sampling sampling; struct common_params_speculative speculative; diff --git a/include/llama.h b/include/llama.h index d2d7f59ebc..703050cf88 100644 --- a/include/llama.h +++ b/include/llama.h @@ -188,6 +188,15 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); + enum llama_compute_type { + LLAMA_COMPUTE_TYPE_DEFAULT = 0, // no override, use model's native precision. + LLAMA_COMPUTE_TYPE_F32 = 1, + LLAMA_COMPUTE_TYPE_F16 = 2, + LLAMA_COMPUTE_TYPE_BF16 = 3, + }; + + LLAMA_API const char * llama_compute_type_name(enum llama_compute_type compute_type); + enum llama_split_mode { LLAMA_SPLIT_MODE_NONE = 0, // single GPU LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -336,6 +345,7 @@ extern "C" { enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention + enum llama_compute_type compute_type; // intermediate activation precision [EXPERIMENTAL] // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fc05989aa5..05353fbbe3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -150,6 +150,53 @@ llama_context::llama_context( cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; + // convert public enum → internal ggml_type for intermediate computation precision + switch (params.compute_type) { + case LLAMA_COMPUTE_TYPE_F32: + cparams.compute_type = GGML_TYPE_F32; + break; + case LLAMA_COMPUTE_TYPE_F16: + cparams.compute_type = GGML_TYPE_F16; + break; + case LLAMA_COMPUTE_TYPE_BF16: + cparams.compute_type = GGML_TYPE_BF16; + break; + case LLAMA_COMPUTE_TYPE_DEFAULT: + default: + // DEFAULT = no override, use model's native precision (F32 for now) + cparams.compute_type = GGML_TYPE_F32; + break; + } + + // Nowadays FP16 and BF16 support is model-specific. + // Add models here as their required ops are 'compute_type' implemented and validated. + auto model_supports_compute_type = [&](ggml_type ct) -> bool { + if (ct == GGML_TYPE_F32) { + return true; // F32 is always supported + } + // Example (uncomment when ready): + // if (ct == GGML_TYPE_F16 || ct == GGML_TYPE_BF16) { + // switch (model.arch) { + // case LLM_ARCH_QWEN2: + // case LLM_ARCH_QWEN2MOE: + // case LLM_ARCH_QWEN3: + // return true; + // default: + // return false; + // } + // } + (void)model.arch; // no models enabled yet for non-F32 compute types + return false; + }; + + if (!model_supports_compute_type(cparams.compute_type)) { + LLAMA_LOG_WARN("%s: model arch '%s' does not yet support compute_type %s, " + "falling back to F32. To enable, the required ops must be implemented first.\n", + __func__, llm_arch_name(model.arch), + ggml_type_name(cparams.compute_type)); + cparams.compute_type = GGML_TYPE_F32; + } + // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -196,6 +243,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); + LLAMA_LOG_INFO("%s: compute_type = %s\n", __func__, llama_compute_type_name(params.compute_type)); LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -2942,6 +2990,7 @@ llama_context_params llama_context_default_params() { /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, + /*.compute_type =*/ LLAMA_COMPUTE_TYPE_DEFAULT, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 2da3bbd6f9..8a37b36e75 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -39,6 +39,8 @@ struct llama_cparams { enum llama_pooling_type pooling_type; + ggml_type compute_type; + ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 70d8ff02a9..5885601d0d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -888,6 +888,30 @@ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { } } +ggml_tensor * llm_graph_context::build_cast_to_compute_type( + ggml_context * ctx, + ggml_tensor * cur) const { + if (cparams.compute_type == GGML_TYPE_F32 || cur->type == cparams.compute_type) { + return cur; + } + return ggml_cast(ctx, cur, cparams.compute_type); +} + +ggml_tensor * llm_graph_context::build_cast_to_f32( + ggml_context * ctx, + ggml_tensor * cur) const { + if (cur->type == GGML_TYPE_F32) { + return cur; + } + return ggml_cast(ctx, cur, GGML_TYPE_F32); +} + +ggml_tensor * llm_graph_context::set_result_logits(ggml_tensor * cur) { + cur = build_cast_to_f32(ctx0, cur); + res->t_logits = cur; + return cur; +} + ggml_tensor * llm_graph_context::build_cvec( ggml_tensor * cur, int il) const { @@ -1552,6 +1576,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { // ref: https://github.com/ggml-org/llama.cpp/pull/18599 ggml_build_forward_expand(gf, cur); + // cast to compute_type if needed (e.g., F16 for intermediate activations) + cur = build_cast_to_compute_type(ctx0, cur); + return cur; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 1d69ff1a6f..920aef80cc 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -756,6 +756,17 @@ struct llm_graph_context { void cb(ggml_tensor * cur, const char * name, int il) const; + // intermediate computation precision. + ggml_tensor * build_cast_to_compute_type( + ggml_context * ctx, + ggml_tensor * cur) const; + + ggml_tensor * build_cast_to_f32( + ggml_context * ctx, + ggml_tensor * cur) const; + + ggml_tensor * set_result_logits(ggml_tensor * cur); + // // common // diff --git a/src/llama.cpp b/src/llama.cpp index 6da90d6f1f..51cfe18c60 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -43,6 +43,20 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } +const char * llama_compute_type_name(enum llama_compute_type compute_type) { + switch (compute_type) { + case LLAMA_COMPUTE_TYPE_DEFAULT: + return "default"; + case LLAMA_COMPUTE_TYPE_F32: + return "f32"; + case LLAMA_COMPUTE_TYPE_F16: + return "f16"; + case LLAMA_COMPUTE_TYPE_BF16: + return "bf16"; + } + GGML_ABORT("fatal error"); +} + struct llama_device_memory_data { int64_t total; int64_t free;