initial ngram-mod proof of concept, score-based pruning

This commit is contained in:
Bernhard Froemel 2026-02-03 11:31:26 +00:00
parent 1f1e57f2bf
commit e543f88952
3 changed files with 149 additions and 5 deletions
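
For orientation, a minimal caller-side sketch of the score lifecycle this commit introduces (token IDs, the table size, and the variable names win and h are placeholders assumed for illustration, not part of the diff):

common_ngram_mod::entry_t win[4] = {1, 2, 3, 4}; // a 3-gram plus the token that follows it
common_ngram_mod mod(/*n=*/3, /*size=*/1 << 16);
mod.add(win);                    // new entry starts at SCORE_INS
const size_t h = mod.index(win); // remember the bucket for later bookkeeping
mod.inc_score_by_index(h);       // drafted token accepted: reward (caps at SCORE_MAX)
mod.dec_score_by_index(h);       // drafted token rejected: penalize (floors at SCORE_MIN)
mod.prune_low_score();           // drop every entry scoring below SCORE_THR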

View File

@@ -6,6 +6,7 @@
common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) {
entries.resize(size);
scores.resize(size, SCORE_INIT);
reset();
}
@@ -27,8 +28,12 @@ void common_ngram_mod::add(const entry_t * tokens) {
if (entries[i] == EMPTY) {
used++;
scores[i] = SCORE_INS;
} else if (entries[i] != tokens[n]) {
// a different token hashes to the same bucket
++collisions;
}
// keep existing score if entry already occupied
entries[i] = tokens[n];
}
@@ -40,7 +45,9 @@ common_ngram_mod::entry_t common_ngram_mod::get(const entry_t * tokens) const {
void common_ngram_mod::reset() {
std::fill(entries.begin(), entries.end(), EMPTY);
std::fill(scores.begin(), scores.end(), 0);
used = 0;
collisions = 0;
}
size_t common_ngram_mod::get_n() const {
@@ -56,5 +63,83 @@ size_t common_ngram_mod::size() const {
}
size_t common_ngram_mod::size_bytes() const {
return entries.size() * sizeof(entries[0]);
return entries.size() * sizeof(entries[0]) + scores.size() * sizeof(scores[0]);
}
size_t common_ngram_mod::index(const entry_t * tokens) const {
return idx(tokens);
}
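// saturating score updates; the *_by_index variants reuse a hash index previously obtained from index()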
void common_ngram_mod::inc_score(const entry_t * tokens) {
const size_t i = idx(tokens);
if (scores[i] < common_ngram_mod::SCORE_MAX) {
++scores[i];
}
}
void common_ngram_mod::dec_score(const entry_t * tokens) {
const size_t i = idx(tokens);
if (scores[i] > common_ngram_mod::SCORE_MIN) {
--scores[i];
}
}
void common_ngram_mod::inc_score_by_index(size_t i) {
if (i < scores.size() && scores[i] < common_ngram_mod::SCORE_MAX) {
++scores[i];
}
}
void common_ngram_mod::dec_score_by_index(size_t i) {
if (i < scores.size() && scores[i] > common_ngram_mod::SCORE_MIN) {
--scores[i];
}
}
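// prune pass: clear every entry whose score fell below SCORE_THR and recompute the fill level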
void common_ngram_mod::prune_low_score() {
used = 0;
for (size_t i = 0; i < entries.size(); ++i) {
if (scores[i] < common_ngram_mod::SCORE_THR) {
entries[i] = EMPTY;
scores[i] = 0;
} else if (entries[i] != EMPTY) {
// empty slots also sit at score 0 and must not count towards occupancy
++used;
}
}
}
size_t common_ngram_mod::get_collisions() const {
return collisions;
}
size_t common_ngram_mod::get_below_thr() const {
return count_below_thr;
}
size_t common_ngram_mod::get_at_min() const {
return count_at_min;
}
size_t common_ngram_mod::get_at_max() const {
return count_at_max;
}
size_t common_ngram_mod::get_at_ins() const {
return count_at_ins;
}
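// recompute the score distribution counters in one pass; a score can fall into more than one bucket (e.g. SCORE_MIN is also below SCORE_THR)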
void common_ngram_mod::update_score_stats() {
// reset counters
count_below_thr = 0;
count_at_min = 0;
count_at_max = 0;
count_at_ins = 0;
for (size_t i = 0; i < scores.size(); ++i) {
const int8_t s = scores[i];
if (s < SCORE_THR) ++count_below_thr;
if (s == SCORE_MIN) ++count_at_min;
if (s == SCORE_MAX) ++count_at_max;
if (s == SCORE_INS) ++count_at_ins;
}
}

View File

@@ -15,6 +15,12 @@ struct common_ngram_mod {
static constexpr entry_t EMPTY = -1;
static constexpr int8_t SCORE_INIT = 0; // initial score of every slot
static constexpr int8_t SCORE_MIN = -5; // lower saturation bound for dec_score()
static constexpr int8_t SCORE_MAX = 20; // upper saturation bound for inc_score()
static constexpr int8_t SCORE_THR = 0; // prune entries scoring below this; must stay <= SCORE_INIT
static constexpr int8_t SCORE_INS = 3; // score assigned to a newly inserted entry (see add())
common_ngram_mod(uint16_t n, size_t size);
size_t idx(const entry_t * tokens) const;
@@ -23,9 +29,27 @@ struct common_ngram_mod {
void reset();
// expose the hash index for external bookkeeping
size_t index(const entry_t * tokens) const;
// score handling
void inc_score(const entry_t * tokens);
void dec_score(const entry_t * tokens);
void inc_score_by_index(size_t i);
void dec_score_by_index(size_t i);
void prune_low_score(); // remove entries below SCORE_THR
size_t get_n() const;
size_t get_used() const;
void update_score_stats();
size_t get_collisions() const;
size_t get_below_thr() const;
size_t get_at_min() const;
size_t get_at_max() const;
size_t get_at_ins() const;
size_t size() const;
size_t size_bytes() const;
@@ -35,4 +59,15 @@ private:
size_t used;
std::vector<entry_t> entries;
// per-entry score, range SCORE_MIN .. SCORE_MAX
std::vector<int8_t> scores;
// stats
// count of hash collisions
size_t collisions = 0;
// score distribution counters, filled by update_score_stats()
size_t count_below_thr = 0;
size_t count_at_min = 0;
size_t count_at_max = 0;
size_t count_at_ins = 0;
};
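
A worked example of the pruning arithmetic under the constants above (a hypothetical standalone snippet, not taken from the diff):

common_ngram_mod mod(/*n=*/2, /*size=*/64);
common_ngram_mod::entry_t w[3] = {10, 11, 12}; // 2-gram key followed by its successor token
mod.add(w);                  // new entry: score = SCORE_INS = 3
for (int k = 0; k < 4; ++k) {
    mod.dec_score(w);        // 3 -> 2 -> 1 -> 0 -> -1
}
mod.prune_low_score();       // -1 < SCORE_THR (0), so the entry is cleared
// mod.get(w) returns common_ngram_mod::EMPTY again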

View File

@@ -527,6 +527,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
// consecutive accept rounds with low acceptance fraction (< 0.5)
int n_low = 0;
// hash indices of ngrams consulted during the most recent draft
std::vector<size_t> used_hashes;
// enable trace logging if LLAMA_TRACE is set
const bool verbose;
@@ -558,7 +560,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
constexpr double f_thold = 0.25;
if (f > f_thold) {
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting (collisions=%zu)\n", __func__, f, f_thold, mod.get_collisions());
mod.reset();
}
@@ -572,6 +574,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
GGML_UNUSED(params);
n_draft_last = 0;
used_hashes.clear();
const size_t cur_len = prompt_tgt.size();
if (cur_len < mod.get_n()) {
@@ -607,6 +610,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
break;
}
result[n + i] = token;
// remember which hash entry produced this token
used_hashes.push_back(mod.index(result.data() + i));
}
// only return the m tokens that were drafted
@@ -627,18 +632,37 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
// compute acceptance fraction if we have a recorded draft length
if (n_draft_last > 0) {
const double f_acc = (double)n_accepted / (double)n_draft_last;
// update per-ngram scores based on acceptance outcome
// bound by both counters in case the recorded hashes and the draft length ever disagree
for (size_t i = 0; i < used_hashes.size() && i < (size_t) n_draft_last; ++i) {
if (i < static_cast<size_t>(n_accepted)) {
mod.inc_score_by_index(used_hashes[i]);
} else {
mod.dec_score_by_index(used_hashes[i]);
}
}
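// three consecutive low-acceptance rounds trigger score-based pruning instead of a full reset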
if (f_acc < 0.5) {
n_low++;
if (n_low >= 3) {
LOG_WRN("%s: low acceptance streak (%d) resetting ngram_mod\n", __func__, n_low);
LOG_WRN("%s: low acceptance streak (%d) - pruning ngram_mod (collisions=%zu)\n", __func__, n_low, mod.get_collisions());
// log detailed score metrics before pruning
mod.update_score_stats();
LOG_WRN("%s: before prune scores - below_thr=%zu, at_min=%zu, at_max=%zu, at_ins=%zu\n",
__func__,
mod.get_below_thr(),
mod.get_at_min(),
mod.get_at_max(),
mod.get_at_ins());
mod.reset();
mod.prune_low_score();
n_low = 0;
}
} else {
n_low = 0;
}
}
used_hashes.clear();
}
};