cont : fixes

Georgi Gerganov 2025-12-07 12:52:25 +02:00
parent 52258181da
commit 8ef5f900db
GPG Key ID: 449E073F9DC10735
8 changed files with 52 additions and 23 deletions

View File

@@ -1286,7 +1286,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.top_k = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
}
).set_sparam());
).set_sparam().set_env("LLAMA_ARG_TOP_K"));
add_opt(common_arg(
{"--top-p"}, "N",
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),

View File

@@ -149,6 +149,7 @@ static __global__ void cumsum_kernel(
}
}
#ifdef GGML_CUDA_USE_CUB
template <typename T>
static void cumsum_cub(ggml_cuda_pool & pool,
const T * src,
@@ -171,6 +172,7 @@ static void cumsum_cub(ggml_cuda_pool & pool,
// Perform the inclusive scan
cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
}
#endif // GGML_CUDA_USE_CUB
template<typename T>
static void cumsum_cuda(
@@ -188,7 +190,7 @@ static void cumsum_cuda(
if (is_contiguous) {
use_cub = true;
int64_t nrows = ne01 * ne02 * ne03;
// TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
// Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
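
Note: the CUB path above uses the standard two-phase cub::DeviceScan::InclusiveSum pattern, where the first call with a null temporary buffer only reports the required scratch size and the second call performs the scan; in the patch the scratch buffer comes from the ggml_cuda_pool. A minimal standalone sketch of the same pattern (plain CUDA C++; buffer names are illustrative and not from the patch):

    #include <cub/cub.cuh>
    #include <cuda_runtime.h>

    // Inclusive prefix sum of n floats on the device, two-phase CUB call.
    static void inclusive_scan_sketch(const float * d_in, float * d_out, int n, cudaStream_t stream) {
        void * d_tmp = nullptr;
        size_t tmp_size = 0;
        // 1st call: d_tmp == nullptr, so CUB only writes the required scratch size into tmp_size
        cub::DeviceScan::InclusiveSum(d_tmp, tmp_size, d_in, d_out, n, stream);
        cudaMalloc(&d_tmp, tmp_size);
        // 2nd call: performs the actual inclusive scan
        cub::DeviceScan::InclusiveSum(d_tmp, tmp_size, d_in, d_out, n, stream);
        cudaFree(d_tmp);
    }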

View File

@@ -537,6 +537,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
}
}
void llm_graph_result::set_outputs() {
if (t_logits != nullptr) {
ggml_set_output(t_logits);
}
if (t_embd != nullptr) {
ggml_set_output(t_embd);
}
if (t_embd_pooled != nullptr) {
ggml_set_output(t_embd_pooled);
}
for (auto & [seq_id, t] : t_sampled) {
if (t != nullptr) {
ggml_set_output(t);
}
}
for (auto & [seq_id, t] : t_sampled_probs) {
if (t != nullptr) {
ggml_set_output(t);
}
}
for (auto & [seq_id, t] : t_sampled_logits) {
if (t != nullptr) {
ggml_set_output(t);
}
}
for (auto & [seq_id, t] : t_candidates) {
if (t != nullptr) {
ggml_set_output(t);
}
}
}
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
if (!this->params.allow_reuse(params)) {
if (debug > 1) {
@@ -2100,25 +2132,21 @@ void llm_graph_context::build_sampling() const {
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
if (data.sampled != nullptr) {
ggml_set_output(data.sampled);
res->t_sampled[seq_id] = data.sampled;
ggml_build_forward_expand(gf, data.sampled);
}
if (data.probs != nullptr) {
ggml_set_output(data.probs);
res->t_sampled_probs[seq_id] = data.probs;
ggml_build_forward_expand(gf, data.probs);
}
if (data.logits != nullptr) {
ggml_set_output(data.logits);
res->t_sampled_logits[seq_id] = data.logits;
ggml_build_forward_expand(gf, data.logits);
}
if (data.candidates != nullptr) {
ggml_set_output(data.candidates);
res->t_candidates[seq_id] = data.candidates;
ggml_build_forward_expand(gf, data.candidates);
}
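
For context: the hunks above move the ggml_set_output() calls out of build_sampling() and into the new llm_graph_result::set_outputs(), which marks every produced tensor (logits, embeddings, per-sequence sampling results) as a graph output in one place; build_sampling() now only records the sampler tensors and expands the graph. A minimal standalone sketch of the underlying ggml idiom, assuming a small f32 tensor (graph construction only; executing it additionally requires a ggml backend):

    #include "ggml.h"

    // Build a tiny graph and mark its result as an output so its buffer
    // stays valid after the graph has been computed.
    static struct ggml_cgraph * build_sketch(struct ggml_context * ctx) {
        struct ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * probs  = ggml_soft_max(ctx, logits);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, probs); // add the node (and its parents) to the graph
        ggml_set_output(probs);               // mark it as an output, as set_outputs() does per tensor

        return gf;
    }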

View File

@@ -523,6 +523,7 @@ public:
void reset();
void set_inputs(const llama_ubatch * ubatch);
void set_outputs();
// try to update the existing graph result using the new graph parameters in order to reuse it
// this can only be done if we determine that the resulting graph using the new graph parameters

View File

@@ -7581,6 +7581,8 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
// TODO: move reranking logic here and generalize
llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
llm->res->set_outputs();
return llm->res->get_gf();
}

View File

@@ -1367,7 +1367,7 @@ static void llama_sampler_top_p_backend_apply(
struct llama_sampler_data * data) {
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
auto ggml_sort = [& ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
GGML_ASSERT(ggml_nrows(a) == 1);
struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
@@ -1386,7 +1386,7 @@ static void llama_sampler_top_p_backend_apply(
ggml_set_name(softmax, "top_p_softmax");
// If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
if (data->candidates != nullptr) {
if (data->candidates) {
data->candidates = ggml_sort(data->candidates, sorted_idx);
} else {
data->candidates = sorted_idx;
@@ -1412,8 +1412,9 @@ static void llama_sampler_top_p_backend_apply(
// Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
// construct ones tensor to set the value in the mask
struct ggml_tensor * ones = ggml_dup_tensor(ctx, mask_reshaped);
ones = ggml_clamp(ctx, ones, 1.0f, 1.0f);
struct ggml_tensor * ones = ggml_clamp(ctx, mask_reshaped, 1.0f, 1.0f);
ggml_set_name(ones, "top_p_ones");
mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, ggml_repeat(ctx, idxf, mask), GGML_TYPE_I32));
mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
@@ -1780,12 +1781,7 @@ static void llama_sampler_backend_temp_sampling(
return;
}
struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / temp);
ggml_set_name(scaled, "temp_scaled");
// Make sure the scaled tensor is contiguous for subsequent operations
data->logits = ggml_cont(ctx, scaled);
ggml_set_name(data->logits, "temp_scaled_logits");
data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
GGML_UNUSED(gf);
}
@@ -3278,7 +3274,7 @@ static void llama_sampler_logit_bias_backend_apply(
}
// Add the sparse logit bias to the logits
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
data->logits = logit_biased;
}
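
Two of the simplifications above are plain ggml identities: ggml_clamp(ctx, t, 1.0f, 1.0f) already yields an all-ones tensor shaped like t, so the intermediate ggml_dup_tensor was unnecessary, and temperature sampling reduces to a single ggml_scale of the logits by 1/temp (the extra ggml_cont appears to have been redundant, since non-inplace ggml_scale writes to a freshly allocated, contiguous tensor). The logit-bias hunk likewise switches from ggml_add_inplace to ggml_add, producing a new tensor instead of overwriting the input logits. A minimal sketch of the two idioms (names are illustrative, not from the patch):

    #include "ggml.h"

    // Two small idioms from the patch, shown in isolation.
    static void sampler_sketch(struct ggml_context * ctx, struct ggml_tensor * logits, float temp) {
        // an all-ones tensor with the same shape as `logits`, no explicit dup needed
        struct ggml_tensor * ones = ggml_clamp(ctx, logits, 1.0f, 1.0f);
        ggml_set_name(ones, "ones");

        // temperature sampling: scale every logit by 1/temp
        struct ggml_tensor * scaled = ggml_scale(ctx, logits, 1.0f / temp);
        ggml_set_name(scaled, "temp_scaled_logits");

        (void) ones; (void) scaled; // these would normally feed further graph nodes
    }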

View File

@@ -18,7 +18,7 @@ def create_server():
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", False, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", True, 'chatml'),
(None, "Book", "What is the best book", 8, "^ blue|very teaful", 23, 8, "length", True, "This is not a chat template, it is"),
(None, "Book", "What is the best book", 8, "^ blue|very teaful|very busy", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length", True, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter|Some", 79, 8, "length", False, None),

View File

@@ -17,7 +17,7 @@ def create_server():
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
("I believe the meaning of life is", 8, "(going|bed)+|froze and every", 18, 8, False, False),
("I believe the meaning of life is", 8, "(going|bed)+|froze and every|froze and bri", 18, 8, False, False),
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
@@ -42,7 +42,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+|froze and every", 18, 8, False),
("I believe the meaning of life is", 8, "(going|bed)+|froze and every|froze and bri", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
@@ -103,7 +103,7 @@ def test_completion_with_openai_library():
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
assert res.choices[0].finish_reason == "length"
assert res.choices[0].text is not None
assert match_regex("(going|bed)+|froze and every", res.choices[0].text)
assert match_regex("(going|bed)+|froze and every|froze and bri", res.choices[0].text)
def test_completion_stream_with_openai_library():
@@ -122,7 +122,7 @@ def test_completion_stream_with_openai_library():
if choice.finish_reason is None:
assert choice.text is not None
output_text += choice.text
assert match_regex("(going|bed)+|froze and every", output_text)
assert match_regex("(going|bed)+|froze and every|froze and bri", output_text)
# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
@@ -146,7 +146,7 @@ def test_completion_stream_with_openai_library_stops():
if choice.finish_reason is None:
assert choice.text is not None
output_text += choice.text
assert match_regex("Sure, here's one for[\\s\\S]*|Sure thing..Why don't", output_text), f'Unexpected output: {output_text}'
assert match_regex("Sure, here's one for[\\s\\S]*|Sure thing..Why don't|Sure! Here's one for you:", output_text), f'Unexpected output: {output_text}'
@pytest.mark.parametrize("n_slots", [1, 2])