// 74 lines, 1.8 KiB, C++
#pragma once

#include <cstdint> // int8_t, int32_t, uint16_t
#include <vector>  // std::vector
#include <cstddef> // size_t

//
// common_ngram_mod
// ref: https://github.com/ggml-org/llama.cpp/pull/19164
//

// basic n-gram hasher
struct common_ngram_mod {
    // Only declarations live here: no member function is defined inside this struct body.

    // element type of `entries`; also the type of the tokens passed to idx/add/get/index
    using entry_t = int32_t;

    // sentinel entry value; equals the -1 that get() is documented to return when nothing is found
    static constexpr entry_t EMPTY = -1;

    // score constants; per-entry scores are stored as int8_t in the range SCORE_MIN .. SCORE_MAX
    static constexpr int8_t SCORE_INIT = 0;
    static constexpr int8_t SCORE_MIN  = -5;
    static constexpr int8_t SCORE_MAX  = 20;
    static constexpr int8_t SCORE_THR  = 0; // prune_low_score() removes entries below this value
    static constexpr int8_t SCORE_INS  = 3; // NOTE(review): presumably the score given on insert - confirm in the implementation

    // n    - ngram size to hash (taken as uint16_t, stored in the size_t member `n`)
    // size - NOTE(review): presumably the number of slots in the table - confirm in the implementation
    common_ngram_mod(uint16_t n, size_t size);

    // NOTE(review): idx() and index() share the same signature; presumably both map `tokens`
    //               to a slot position - confirm in the implementation whether both are needed
    size_t idx(const entry_t * tokens) const;

    // NOTE(review): how many elements are read from `tokens` is not visible here
    //               (presumably related to the ngram size) - confirm in the implementation
    void add(const entry_t * tokens);
    entry_t get(const entry_t * tokens) const; // return -1 if not found

    void reset();

    // expose the hash index for external bookkeeping
    size_t index(const entry_t * tokens) const;

    // score handling
    // the *_by_index variants take a position `i` instead of a token pointer
    // (presumably a value obtained from index() - confirm in the implementation)
    void inc_score(const entry_t * tokens);
    void dec_score(const entry_t * tokens);
    void inc_score_by_index(size_t i);
    void dec_score_by_index(size_t i);
    void prune_low_score(); // remove entries below SCORE_THR

    size_t get_n() const;
    size_t get_used() const;

    // NOTE(review): presumably recomputes the score counters read by the getters below - confirm
    void update_score_stats();

    // accessors for the statistics counters declared in the private section
    size_t get_collisions() const;
    size_t get_below_thr() const;
    size_t get_at_min() const;
    size_t get_at_max() const;
    size_t get_at_ins() const;

    size_t size() const;
    size_t size_bytes() const;

private:
    // zero-initialized by default, like the statistics counters below
    size_t n = 0; // ngram size to hash

    size_t used = 0;

    std::vector<entry_t> entries;
    // per-entry score, range SCORE_MIN .. SCORE_MAX
    std::vector<int8_t> scores;

    // stats
    // count of hash collisions
    size_t collisions = 0;
    // counts for score
    size_t count_below_thr = 0;
    size_t count_at_min    = 0;
    size_t count_at_max    = 0;
    size_t count_at_ins    = 0;
};