server : fix checkpoints n_tokens calculation (#20287)
This commit is contained in:
parent
ed0007aa32
commit
96cfc4992c
|
|
@ -2141,6 +2141,9 @@ private:
|
|||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||
const auto & input_tokens = slot.task->tokens;
|
||||
|
||||
// used to determine the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_prev = batch.n_tokens;
|
||||
|
||||
// TODO: maybe move branch to outside of this loop in the future
|
||||
if (slot.state == SLOT_STATE_STARTED) {
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
|
|
@ -2533,6 +2536,9 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
// the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_cur = batch.n_tokens - n_tokens_prev;
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
|
@ -2593,7 +2599,7 @@ private:
|
|||
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
|
||||
/*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue