context : add set_result_logits to cast from compute_type back to F32 for the final logits output

This commit is contained in:
shaobo.xie 2026-02-12 12:50:55 +08:00
parent f47e50a18b
commit d38923a1ec
2 changed files with 13 additions and 1 deletion

View File

@ -880,6 +880,12 @@ ggml_tensor * llm_graph_context::build_cast_to_f32(
return ggml_cast(ctx, cur, GGML_TYPE_F32);
}
// Finalize the logits tensor: cast it back from the intermediate
// compute type (e.g. F16) to F32 and record it as the graph result.
// Returns the (possibly cast) tensor so callers can keep chaining.
ggml_tensor * llm_graph_context::set_result_logits(ggml_tensor * cur) {
    ggml_tensor * logits = build_cast_to_f32(ctx0, cur);
    res->t_logits = logits;
    return logits;
}
ggml_tensor * llm_graph_context::build_cvec(
ggml_tensor * cur,
int il) const {
@ -1544,6 +1550,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
// ref: https://github.com/ggml-org/llama.cpp/pull/18599
ggml_build_forward_expand(gf, cur);
// cast to compute_type if needed (e.g., F16 for intermediate activations)
cur = build_cast_to_compute_type(ctx0, cur);
return cur;
}

View File

@ -756,7 +756,8 @@ struct llm_graph_context {
void cb(ggml_tensor * cur, const char * name, int il) const;
ggml_tensor * build_cast_to_compute_type( // intermediate computation precision.
// intermediate computation precision.
ggml_tensor * build_cast_to_compute_type(
ggml_context * ctx,
ggml_tensor * cur) const;
@ -764,6 +765,8 @@ struct llm_graph_context {
ggml_context * ctx,
ggml_tensor * cur) const;
ggml_tensor * set_result_logits(ggml_tensor * cur);
//
// common
//