context : add set_result_logits to cast from compute_type back to F32 for the final logits output
This commit is contained in:
parent
f47e50a18b
commit
d38923a1ec
|
|
@ -880,6 +880,12 @@ ggml_tensor * llm_graph_context::build_cast_to_f32(
|
|||
return ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::set_result_logits(ggml_tensor * cur) {
|
||||
cur = build_cast_to_f32(ctx0, cur);
|
||||
res->t_logits = cur;
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_cvec(
|
||||
ggml_tensor * cur,
|
||||
int il) const {
|
||||
|
|
@ -1544,6 +1550,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|||
// ref: https://github.com/ggml-org/llama.cpp/pull/18599
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
// cast to compute_type if needed (e.g., F16 for intermediate activations)
|
||||
cur = build_cast_to_compute_type(ctx0, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -756,7 +756,8 @@ struct llm_graph_context {
|
|||
|
||||
void cb(ggml_tensor * cur, const char * name, int il) const;
|
||||
|
||||
ggml_tensor * build_cast_to_compute_type( // intermediate computation precision.
|
||||
// intermediate computation precision.
|
||||
ggml_tensor * build_cast_to_compute_type(
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur) const;
|
||||
|
||||
|
|
@ -764,6 +765,8 @@ struct llm_graph_context {
|
|||
ggml_context * ctx,
|
||||
ggml_tensor * cur) const;
|
||||
|
||||
ggml_tensor * set_result_logits(ggml_tensor * cur);
|
||||
|
||||
//
|
||||
// common
|
||||
//
|
||||
|
|
|
|||
Loading…
Reference in New Issue