initial ngram-mod proof of concept, score-based pruning
This commit is contained in:
parent
1f1e57f2bf
commit
e543f88952
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) {
|
common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) {
|
||||||
entries.resize(size);
|
entries.resize(size);
|
||||||
|
scores.resize(size, SCORE_INIT);
|
||||||
|
|
||||||
reset();
|
reset();
|
||||||
}
|
}
|
||||||
|
|
@ -27,8 +28,12 @@ void common_ngram_mod::add(const entry_t * tokens) {
|
||||||
|
|
||||||
if (entries[i] == EMPTY) {
|
if (entries[i] == EMPTY) {
|
||||||
used++;
|
used++;
|
||||||
|
scores[i] = SCORE_INS;
|
||||||
|
} else if (entries[i] != tokens[n]) {
|
||||||
|
// a different token hashes to the same bucket
|
||||||
|
++collisions;
|
||||||
}
|
}
|
||||||
|
// keep existing score if entry already occupied
|
||||||
entries[i] = tokens[n];
|
entries[i] = tokens[n];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -40,7 +45,9 @@ common_ngram_mod::entry_t common_ngram_mod::get(const entry_t * tokens) const {
|
||||||
|
|
||||||
void common_ngram_mod::reset() {
|
void common_ngram_mod::reset() {
|
||||||
std::fill(entries.begin(), entries.end(), EMPTY);
|
std::fill(entries.begin(), entries.end(), EMPTY);
|
||||||
|
std::fill(scores.begin(), scores.end(), 0);
|
||||||
used = 0;
|
used = 0;
|
||||||
|
collisions = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t common_ngram_mod::get_n() const {
|
size_t common_ngram_mod::get_n() const {
|
||||||
|
|
@ -56,5 +63,83 @@ size_t common_ngram_mod::size() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t common_ngram_mod::size_bytes() const {
|
size_t common_ngram_mod::size_bytes() const {
|
||||||
return entries.size() * sizeof(entries[0]);
|
return entries.size() * sizeof(entries[0]) + scores.size() * sizeof(scores[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t common_ngram_mod::index(const entry_t * tokens) const {
|
||||||
|
return idx(tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_ngram_mod::inc_score(const entry_t * tokens) {
|
||||||
|
const size_t i = idx(tokens);
|
||||||
|
if (scores[i] < common_ngram_mod::SCORE_MAX) {
|
||||||
|
++scores[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_ngram_mod::dec_score(const entry_t * tokens) {
|
||||||
|
const size_t i = idx(tokens);
|
||||||
|
if (scores[i] > common_ngram_mod::SCORE_MIN) {
|
||||||
|
--scores[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_ngram_mod::inc_score_by_index(size_t i) {
|
||||||
|
if (i < scores.size() && scores[i] < common_ngram_mod::SCORE_MAX) {
|
||||||
|
++scores[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_ngram_mod::dec_score_by_index(size_t i) {
|
||||||
|
if (i < scores.size() && scores[i] > common_ngram_mod::SCORE_MIN) {
|
||||||
|
--scores[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_ngram_mod::prune_low_score() {
|
||||||
|
used = 0;
|
||||||
|
for (size_t i = 0; i < entries.size(); ++i) {
|
||||||
|
if (scores[i] < common_ngram_mod::SCORE_THR) {
|
||||||
|
entries[i] = EMPTY;
|
||||||
|
scores[i] = 0;
|
||||||
|
} else {
|
||||||
|
++used;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t common_ngram_mod::get_collisions() const {
|
||||||
|
return collisions;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t common_ngram_mod::get_below_thr() const {
|
||||||
|
return count_below_thr;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t common_ngram_mod::get_at_min() const {
|
||||||
|
return count_at_min;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t common_ngram_mod::get_at_max() const {
|
||||||
|
return count_at_max;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t common_ngram_mod::get_at_ins() const {
|
||||||
|
return count_at_ins;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_ngram_mod::update_score_stats() {
|
||||||
|
// reset counters
|
||||||
|
count_below_thr = 0;
|
||||||
|
count_at_min = 0;
|
||||||
|
count_at_max = 0;
|
||||||
|
count_at_ins = 0;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < scores.size(); ++i) {
|
||||||
|
const int8_t s = scores[i];
|
||||||
|
if (s < SCORE_THR) ++count_below_thr;
|
||||||
|
if (s == SCORE_MIN) ++count_at_min;
|
||||||
|
if (s == SCORE_MAX) ++count_at_max;
|
||||||
|
if (s == SCORE_INS) ++count_at_ins;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,12 @@ struct common_ngram_mod {
|
||||||
|
|
||||||
static constexpr entry_t EMPTY = -1;
|
static constexpr entry_t EMPTY = -1;
|
||||||
|
|
||||||
|
static constexpr int8_t SCORE_INIT = 0;
|
||||||
|
static constexpr int8_t SCORE_MIN = -5;
|
||||||
|
static constexpr int8_t SCORE_MAX = 20;
|
||||||
|
static constexpr int8_t SCORE_THR = 0; // keep equal or lower than SCORE_INIT
|
||||||
|
static constexpr int8_t SCORE_INS = 3;
|
||||||
|
|
||||||
common_ngram_mod(uint16_t n, size_t size);
|
common_ngram_mod(uint16_t n, size_t size);
|
||||||
|
|
||||||
size_t idx(const entry_t * tokens) const;
|
size_t idx(const entry_t * tokens) const;
|
||||||
|
|
@ -23,9 +29,27 @@ struct common_ngram_mod {
|
||||||
|
|
||||||
void reset();
|
void reset();
|
||||||
|
|
||||||
|
// expose the hash index for external bookkeeping
|
||||||
|
size_t index(const entry_t * tokens) const;
|
||||||
|
|
||||||
|
// score handling
|
||||||
|
void inc_score(const entry_t * tokens);
|
||||||
|
void dec_score(const entry_t * tokens);
|
||||||
|
void inc_score_by_index(size_t i);
|
||||||
|
void dec_score_by_index(size_t i);
|
||||||
|
void prune_low_score(); // remove entries below SCORE_THR
|
||||||
|
|
||||||
size_t get_n() const;
|
size_t get_n() const;
|
||||||
size_t get_used() const;
|
size_t get_used() const;
|
||||||
|
|
||||||
|
void update_score_stats();
|
||||||
|
|
||||||
|
size_t get_collisions() const;
|
||||||
|
size_t get_below_thr() const;
|
||||||
|
size_t get_at_min() const;
|
||||||
|
size_t get_at_max() const;
|
||||||
|
size_t get_at_ins() const;
|
||||||
|
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
size_t size_bytes() const;
|
size_t size_bytes() const;
|
||||||
|
|
||||||
|
|
@ -35,4 +59,15 @@ private:
|
||||||
size_t used;
|
size_t used;
|
||||||
|
|
||||||
std::vector<entry_t> entries;
|
std::vector<entry_t> entries;
|
||||||
|
// per-entry score, range SCORE_MIN .. SCORE_MAX
|
||||||
|
std::vector<int8_t> scores;
|
||||||
|
|
||||||
|
// stats
|
||||||
|
// count of hash collisions
|
||||||
|
size_t collisions = 0;
|
||||||
|
// counts for score
|
||||||
|
size_t count_below_thr = 0;
|
||||||
|
size_t count_at_min = 0;
|
||||||
|
size_t count_at_max = 0;
|
||||||
|
size_t count_at_ins = 0;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -527,6 +527,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||||
|
|
||||||
// consecutive accept rounds with low acceptance fraction (< 0.5)
|
// consecutive accept rounds with low acceptance fraction (< 0.5)
|
||||||
int n_low = 0;
|
int n_low = 0;
|
||||||
|
// hash indices of ngrams consulted during the most recent draft
|
||||||
|
std::vector<size_t> used_hashes;
|
||||||
|
|
||||||
// enable trace logging if LLAMA_TRACE is set
|
// enable trace logging if LLAMA_TRACE is set
|
||||||
const bool verbose;
|
const bool verbose;
|
||||||
|
|
@ -558,7 +560,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||||
|
|
||||||
constexpr double f_thold = 0.25;
|
constexpr double f_thold = 0.25;
|
||||||
if (f > f_thold) {
|
if (f > f_thold) {
|
||||||
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
|
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting (collisions=%zu)\n", __func__, f, f_thold, mod.get_collisions());
|
||||||
|
|
||||||
mod.reset();
|
mod.reset();
|
||||||
}
|
}
|
||||||
|
|
@ -572,6 +574,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||||
GGML_UNUSED(params);
|
GGML_UNUSED(params);
|
||||||
|
|
||||||
n_draft_last = 0;
|
n_draft_last = 0;
|
||||||
|
used_hashes.clear();
|
||||||
|
|
||||||
const size_t cur_len = prompt_tgt.size();
|
const size_t cur_len = prompt_tgt.size();
|
||||||
if (cur_len < mod.get_n()) {
|
if (cur_len < mod.get_n()) {
|
||||||
|
|
@ -607,6 +610,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
result[n + i] = token;
|
result[n + i] = token;
|
||||||
|
// remember which hash entry produced this token
|
||||||
|
used_hashes.push_back(mod.index(result.data() + i));
|
||||||
}
|
}
|
||||||
|
|
||||||
// only return the m tokens that were drafted
|
// only return the m tokens that were drafted
|
||||||
|
|
@ -627,18 +632,37 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||||
// compute acceptance fraction if we have a recorded draft length
|
// compute acceptance fraction if we have a recorded draft length
|
||||||
if (n_draft_last > 0) {
|
if (n_draft_last > 0) {
|
||||||
const double f_acc = (double)n_accepted / (double)n_draft_last;
|
const double f_acc = (double)n_accepted / (double)n_draft_last;
|
||||||
|
|
||||||
|
// update per-ngram scores based on acceptance outcome
|
||||||
|
for (size_t i = 0; i < n_draft_last; ++i) {
|
||||||
|
if (i < static_cast<size_t>(n_accepted)) {
|
||||||
|
mod.inc_score_by_index(used_hashes[i]);
|
||||||
|
} else {
|
||||||
|
mod.dec_score_by_index(used_hashes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (f_acc < 0.5) {
|
if (f_acc < 0.5) {
|
||||||
n_low++;
|
n_low++;
|
||||||
if (n_low >= 3) {
|
if (n_low >= 3) {
|
||||||
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);
|
LOG_WRN("%s: low acceptance streak (%d) - pruning ngram_mod (collisions=%zu)\n", __func__, n_low, mod.get_collisions());
|
||||||
|
// Log detailed score metrics before pruning
|
||||||
|
mod.update_score_stats();
|
||||||
|
LOG_WRN("%s: before prune scores - below_thr=%zu, at_min=%zu, at_max=%zu, at_ins=%zu\n",
|
||||||
|
__func__,
|
||||||
|
mod.get_below_thr(),
|
||||||
|
mod.get_at_min(),
|
||||||
|
mod.get_at_max(),
|
||||||
|
mod.get_at_ins());
|
||||||
|
|
||||||
mod.reset();
|
mod.prune_low_score();
|
||||||
n_low = 0;
|
n_low = 0;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
n_low = 0;
|
n_low = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
used_hashes.clear();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue