initial ngram-mod proof of concept, score-based pruning
This commit is contained in:
parent
1f1e57f2bf
commit
e543f88952
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) {
    // Size both parallel tables up front (one score slot per entry bucket),
    // then bring the whole table to a known-empty state.
    scores.resize(size, SCORE_INIT);
    entries.resize(size);
    reset();
}
|
||||
|
|
@ -27,8 +28,12 @@ void common_ngram_mod::add(const entry_t * tokens) {
|
|||
|
||||
if (entries[i] == EMPTY) {
|
||||
used++;
|
||||
scores[i] = SCORE_INS;
|
||||
} else if (entries[i] != tokens[n]) {
|
||||
// a different token hashes to the same bucket
|
||||
++collisions;
|
||||
}
|
||||
|
||||
// keep existing score if entry already occupied
|
||||
entries[i] = tokens[n];
|
||||
}
|
||||
|
||||
|
|
@ -40,7 +45,9 @@ common_ngram_mod::entry_t common_ngram_mod::get(const entry_t * tokens) const {
|
|||
|
||||
void common_ngram_mod::reset() {
|
||||
std::fill(entries.begin(), entries.end(), EMPTY);
|
||||
std::fill(scores.begin(), scores.end(), 0);
|
||||
used = 0;
|
||||
collisions = 0;
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::get_n() const {
|
||||
|
|
@ -56,5 +63,83 @@ size_t common_ngram_mod::size() const {
|
|||
}
|
||||
|
||||
size_t common_ngram_mod::size_bytes() const {
|
||||
return entries.size() * sizeof(entries[0]);
|
||||
return entries.size() * sizeof(entries[0]) + scores.size() * sizeof(scores[0]);
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::index(const entry_t * tokens) const {
|
||||
return idx(tokens);
|
||||
}
|
||||
|
||||
void common_ngram_mod::inc_score(const entry_t * tokens) {
|
||||
const size_t i = idx(tokens);
|
||||
if (scores[i] < common_ngram_mod::SCORE_MAX) {
|
||||
++scores[i];
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_mod::dec_score(const entry_t * tokens) {
|
||||
const size_t i = idx(tokens);
|
||||
if (scores[i] > common_ngram_mod::SCORE_MIN) {
|
||||
--scores[i];
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_mod::inc_score_by_index(size_t i) {
|
||||
if (i < scores.size() && scores[i] < common_ngram_mod::SCORE_MAX) {
|
||||
++scores[i];
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_mod::dec_score_by_index(size_t i) {
|
||||
if (i < scores.size() && scores[i] > common_ngram_mod::SCORE_MIN) {
|
||||
--scores[i];
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_mod::prune_low_score() {
|
||||
used = 0;
|
||||
for (size_t i = 0; i < entries.size(); ++i) {
|
||||
if (scores[i] < common_ngram_mod::SCORE_THR) {
|
||||
entries[i] = EMPTY;
|
||||
scores[i] = 0;
|
||||
} else {
|
||||
++used;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Number of hash-bucket collisions recorded since the last reset().
size_t common_ngram_mod::get_collisions() const {
    return collisions;
}
|
||||
|
||||
// Count of entries whose score is below SCORE_THR, as of the last
// update_score_stats() call.
size_t common_ngram_mod::get_below_thr() const {
    return count_below_thr;
}
|
||||
|
||||
// Count of entries saturated at SCORE_MIN, as of the last
// update_score_stats() call.
size_t common_ngram_mod::get_at_min() const {
    return count_at_min;
}
|
||||
|
||||
// Count of entries saturated at SCORE_MAX, as of the last
// update_score_stats() call.
size_t common_ngram_mod::get_at_max() const {
    return count_at_max;
}
|
||||
|
||||
// Count of entries still at the insertion score SCORE_INS, as of the last
// update_score_stats() call.
size_t common_ngram_mod::get_at_ins() const {
    return count_at_ins;
}
|
||||
|
||||
void common_ngram_mod::update_score_stats() {
|
||||
// reset counters
|
||||
count_below_thr = 0;
|
||||
count_at_min = 0;
|
||||
count_at_max = 0;
|
||||
count_at_ins = 0;
|
||||
|
||||
for (size_t i = 0; i < scores.size(); ++i) {
|
||||
const int8_t s = scores[i];
|
||||
if (s < SCORE_THR) ++count_below_thr;
|
||||
if (s == SCORE_MIN) ++count_at_min;
|
||||
if (s == SCORE_MAX) ++count_at_max;
|
||||
if (s == SCORE_INS) ++count_at_ins;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,12 @@ struct common_ngram_mod {
|
|||
|
||||
static constexpr entry_t EMPTY = -1;
|
||||
|
||||
static constexpr int8_t SCORE_INIT = 0;
|
||||
static constexpr int8_t SCORE_MIN = -5;
|
||||
static constexpr int8_t SCORE_MAX = 20;
|
||||
static constexpr int8_t SCORE_THR = 0; // keep equal or lower than SCORE_INIT
|
||||
static constexpr int8_t SCORE_INS = 3;
|
||||
|
||||
common_ngram_mod(uint16_t n, size_t size);
|
||||
|
||||
size_t idx(const entry_t * tokens) const;
|
||||
|
|
@ -23,9 +29,27 @@ struct common_ngram_mod {
|
|||
|
||||
void reset();
|
||||
|
||||
// expose the hash index for external bookkeeping
|
||||
size_t index(const entry_t * tokens) const;
|
||||
|
||||
// score handling
|
||||
void inc_score(const entry_t * tokens);
|
||||
void dec_score(const entry_t * tokens);
|
||||
void inc_score_by_index(size_t i);
|
||||
void dec_score_by_index(size_t i);
|
||||
void prune_low_score(); // remove entries below SCORE_THR
|
||||
|
||||
size_t get_n() const;
|
||||
size_t get_used() const;
|
||||
|
||||
void update_score_stats();
|
||||
|
||||
size_t get_collisions() const;
|
||||
size_t get_below_thr() const;
|
||||
size_t get_at_min() const;
|
||||
size_t get_at_max() const;
|
||||
size_t get_at_ins() const;
|
||||
|
||||
size_t size() const;
|
||||
size_t size_bytes() const;
|
||||
|
||||
|
|
@ -35,4 +59,15 @@ private:
|
|||
size_t used;
|
||||
|
||||
std::vector<entry_t> entries;
|
||||
// per-entry score, range SCORE_MIN .. SCORE_MAX
|
||||
std::vector<int8_t> scores;
|
||||
|
||||
// stats
|
||||
// count of hash collisions
|
||||
size_t collisions = 0;
|
||||
// counts for score
|
||||
size_t count_below_thr = 0;
|
||||
size_t count_at_min = 0;
|
||||
size_t count_at_max = 0;
|
||||
size_t count_at_ins = 0;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -527,6 +527,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
|||
|
||||
// consecutive accept rounds with low acceptance fraction (< 0.5)
|
||||
int n_low = 0;
|
||||
// hash indices of ngrams consulted during the most recent draft
|
||||
std::vector<size_t> used_hashes;
|
||||
|
||||
// enable trace logging if LLAMA_TRACE is set
|
||||
const bool verbose;
|
||||
|
|
@ -558,7 +560,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
|||
|
||||
constexpr double f_thold = 0.25;
|
||||
if (f > f_thold) {
|
||||
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
|
||||
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting (collisions=%zu)\n", __func__, f, f_thold, mod.get_collisions());
|
||||
|
||||
mod.reset();
|
||||
}
|
||||
|
|
@ -572,6 +574,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
|||
GGML_UNUSED(params);
|
||||
|
||||
n_draft_last = 0;
|
||||
used_hashes.clear();
|
||||
|
||||
const size_t cur_len = prompt_tgt.size();
|
||||
if (cur_len < mod.get_n()) {
|
||||
|
|
@ -607,6 +610,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
|||
break;
|
||||
}
|
||||
result[n + i] = token;
|
||||
// remember which hash entry produced this token
|
||||
used_hashes.push_back(mod.index(result.data() + i));
|
||||
}
|
||||
|
||||
// only return the m tokens that were drafted
|
||||
|
|
@ -627,18 +632,37 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
|||
// compute acceptance fraction if we have a recorded draft length
|
||||
if (n_draft_last > 0) {
|
||||
const double f_acc = (double)n_accepted / (double)n_draft_last;
|
||||
|
||||
// update per-ngram scores based on acceptance outcome
|
||||
for (size_t i = 0; i < n_draft_last; ++i) {
|
||||
if (i < static_cast<size_t>(n_accepted)) {
|
||||
mod.inc_score_by_index(used_hashes[i]);
|
||||
} else {
|
||||
mod.dec_score_by_index(used_hashes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (f_acc < 0.5) {
|
||||
n_low++;
|
||||
if (n_low >= 3) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
|
||||
LOG_WRN("%s: low acceptance streak (%d) - pruning ngram_mod (collisions=%zu)\n", __func__, n_low, mod.get_collisions());
|
||||
// Log detailed score metrics before pruning
|
||||
mod.update_score_stats();
|
||||
LOG_WRN("%s: before prune scores - below_thr=%zu, at_min=%zu, at_max=%zu, at_ins=%zu\n",
|
||||
__func__,
|
||||
mod.get_below_thr(),
|
||||
mod.get_at_min(),
|
||||
mod.get_at_max(),
|
||||
mod.get_at_ins());
|
||||
|
||||
mod.reset();
|
||||
mod.prune_low_score();
|
||||
n_low = 0;
|
||||
}
|
||||
} else {
|
||||
n_low = 0;
|
||||
}
|
||||
}
|
||||
used_hashes.clear();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue