From d38923a1ecd64c919767cb867f4a898d621daeec Mon Sep 17 00:00:00 2001
From: "shaobo.xie"
Date: Thu, 12 Feb 2026 12:50:55 +0800
Subject: [PATCH] context : add set_result_logits to cast from compute_type
 back to F32 for the final logits output

---
 src/llama-graph.cpp | 9 +++++++++
 src/llama-graph.h   | 5 ++++-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1bee470264..488b055aa9 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -880,6 +880,12 @@ ggml_tensor * llm_graph_context::build_cast_to_f32(
     return ggml_cast(ctx, cur, GGML_TYPE_F32);
 }
 
+ggml_tensor * llm_graph_context::set_result_logits(ggml_tensor * cur) {
+    cur = build_cast_to_f32(ctx0, cur);
+    res->t_logits = cur;
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
         int il) const {
@@ -1544,6 +1550,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     // ref: https://github.com/ggml-org/llama.cpp/pull/18599
     ggml_build_forward_expand(gf, cur);
 
+    // cast to compute_type if needed (e.g., F16 for intermediate activations)
+    cur = build_cast_to_compute_type(ctx0, cur);
+
     return cur;
 }
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index ce440ef76b..920aef80cc 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -756,7 +756,8 @@ struct llm_graph_context {
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
-    ggml_tensor * build_cast_to_compute_type( // intermediate computation precision.
+    // intermediate computation precision.
+    ggml_tensor * build_cast_to_compute_type(
             ggml_context * ctx,
             ggml_tensor * cur) const;
 
@@ -764,6 +765,8 @@ struct llm_graph_context {
             ggml_context * ctx,
             ggml_tensor * cur) const;
 
+    ggml_tensor * set_result_logits(ggml_tensor * cur);
+
     //
     // common
     //
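
For reference, below is a minimal sketch of how a model graph builder might call the new helper; it is not part of the patch. The surrounding lines (build_lora_mm on model.output, the cb callback, ggml_build_forward_expand) follow the usual tail of a build_*() function in src/llama-model.cpp, but their exact placement here is assumed for illustration:

    // tail of a hypothetical model build_*() function (illustrative fragment, not from this patch)
    cur = build_lora_mm(model.output, cur);  // lm_head projection; cur may still be in compute_type (e.g. F16)
    cb(cur, "result_output", -1);

    // previously: res->t_logits = cur;      // logits would stay in compute_type
    cur = set_result_logits(cur);            // casts back to F32 and registers the tensor as res->t_logits

    ggml_build_forward_expand(gf, cur);

With this pattern the intermediate activations can run in the reduced compute type while callers of llama_get_logits() still receive F32 values.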