llama : partially apply clang-format style
This commit is contained in:
parent 691698e152
commit e3fe61203c
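All of the hunks below are formatting-only: spaces added around binary operators and inside braced initializer lists, short member functions collapsed onto one line, short if-bodies expanded onto their own lines, wrapped LLAMA_LOG_ERROR calls joined onto a single line, and a few whitespace-only adjustments. As a rough guide to the style being applied, a minimal .clang-format sketch that would produce this kind of output is shown below; the option values are inferred from the hunks for illustration only and are not taken from the repository's actual configuration.

    # illustrative sketch only -- inferred from the hunks, not the repository's real .clang-format
    BasedOnStyle: LLVM
    IndentWidth: 4
    ColumnLimit: 125                           # assumed wide enough for the joined LLAMA_LOG_ERROR calls
    PointerAlignment: Middle                   # ggml_tensor * k
    AllowShortFunctionsOnASingleLine: Inline   # bool is_empty() const { return seq_id.empty(); }
    AllowShortBlocksOnASingleLine: Never       # if (size == 0) { return true; } is expanded
    AllowShortIfStatementsOnASingleLine: Never
    Cpp11BracedListStyle: false                # { cell.pos, cell_id } keeps spaces inside the braces

A partial application like this one is typically produced by restricting the formatter to selected line ranges (for example with clang-format's --lines=<start>:<end> option, or with git clang-format over staged changes) rather than reformatting whole files.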
@@ -2583,7 +2583,7 @@ struct llama_hparams {
         return n_embd_head_v * n_head_kv;
     }

-    uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings
+    uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings
         // TODO: support using an SSM in place of the MLP of a Transformer
         if (n_head_kv(il) != 0) { return 0; }
         // corresponds to Mamba's conv_states size or RWKV's token_shift states size
@@ -2597,7 +2597,7 @@ struct llama_hparams {
         }
     }

-    uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings
+    uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings
         // TODO: support using an SSM in place of the MLP of a Transformer
         if (n_head_kv(il) != 0) { return 0; }

@@ -2875,17 +2875,13 @@ struct llama_kv_self_cache {

 struct llama_rs_cell {
     llama_pos pos = -1;
-    int32_t src = -1; // copy source id (cleared next when -1)
+    int32_t src = -1; // copy source id (cleared next when -1)

     std::set<llama_seq_id> seq_id;

-    bool has_seq_id(const llama_seq_id & id) const {
-        return seq_id.find(id) != seq_id.end();
-    }
+    bool has_seq_id(const llama_seq_id & id) const { return seq_id.find(id) != seq_id.end(); }

-    bool is_empty() const {
-        return seq_id.empty();
-    }
+    bool is_empty() const { return seq_id.empty(); }
 };

 struct llama_rs_seq_meta {
@@ -2895,24 +2891,23 @@ struct llama_rs_seq_meta {

 // ring-buffered tree of cached recurrent state data
 struct llama_rs_self_cache {

-    uint32_t head = 0; // first state used for the last slot
+    uint32_t head = 0; // first state used for the last slot
     uint32_t size = 0;
     uint32_t used = 0;

     // computed when finding a slot
-    uint32_t n = 0; // range of states used for the last slot
+    uint32_t n = 0; // range of states used for the last slot

     // with state models, a cell can hold the state for more than one past token
     std::vector<llama_rs_cell> cells;

     // find tail cells faster
-    std::vector<llama_rs_seq_meta> seq_tails; // map seq_ids to cell ids
+    std::vector<llama_rs_seq_meta> seq_tails; // map seq_ids to cell ids

     // per layer
     // NOTE: the naming of r and s is arbitrary
-    std::vector<struct ggml_tensor *> r_l; // rolling/shift states
-    std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states
+    std::vector<struct ggml_tensor *> r_l; // rolling/shift states
+    std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states

     // Inefficient, but thorough verification and rebuilding of the rs cache
     // from only the cells list with `pos` and seq_ids.
@@ -2920,21 +2915,21 @@ struct llama_rs_self_cache {
     bool rebuild(bool debug) {
         bool was_valid = true;
         // skip for non-recurrent models
-        if (size == 0) { return true; }
+        if (size == 0) {
+            return true;
+        }
         // the source of truth is the cells list
         // buffer sizes
         if (size != cells.size()) {
             if (debug) {
-                LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n",
-                    __func__, cells.size(), size);
+                LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", __func__, cells.size(), size);
             }
             cells.resize(size);
             was_valid = false;
         }
         if (size != seq_tails.size()) {
             if (debug) {
-                LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n",
-                    __func__, seq_tails.size(), size);
+                LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", __func__, seq_tails.size(), size);
             }
             seq_tails.resize(size);
             was_valid = false;
@@ -2994,7 +2989,7 @@ struct llama_rs_self_cache {
             for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
                 llama_rs_cell & cell = cells[cell_id];
                 if (cell.has_seq_id(seq_id)) {
-                    seq_cells.push_back({cell.pos, cell_id});
+                    seq_cells.push_back({ cell.pos, cell_id });
                 }
             }
             // sort by pos and then by cell_id
@@ -3718,16 +3713,16 @@ static bool llama_kv_cache_init(
         }

         if (has_kv) {
-            ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size);
-            ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size);
+            ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i) * kv_size);
+            ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i) * kv_size);
             ggml_format_name(k, "cache_k_l%d", i);
             ggml_format_name(v, "cache_v_l%d", i);
             cache.kv.k_l.push_back(k);
             cache.kv.v_l.push_back(v);
         }
         if (has_rs) {
-            ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size);
-            ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size);
+            ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i) * rs_size);
+            ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i) * rs_size);
             ggml_format_name(r, "cache_r_l%d", i);
             ggml_format_name(s, "cache_s_l%d", i);
             cache.rs.r_l.push_back(r);
@@ -4370,8 +4365,8 @@ struct llama_kv_slot_restorer {
     bool do_restore = false;

     explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
-        old_state.head = cache.kv.head;
-        old_state.n = cache.kv.n;
+        old_state.head = cache.kv.head;
+        old_state.n = cache.kv.n;
     }

     // saves a slot information for future restoration
@@ -4388,10 +4383,10 @@ struct llama_kv_slot_restorer {
     // and rollback changes from all llama_kv_cache_find_slot calls
     void restore(struct llama_kv_cache & cache) {
         if (do_restore) {
-            cache.kv.head = old_state.head;
-            cache.kv.n = old_state.n;
+            cache.kv.head = old_state.head;
+            cache.kv.n = old_state.n;

-            if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased
+            if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased
                 llama_kv_cache_seq_rm(cache, -1, -1, -1);
             } else {
                 for (auto & slot : slot_boundaries) {