llama : partially apply clang-format style

This commit is contained in:
Francis Couture-Harpin 2024-11-25 11:31:46 -05:00
parent 691698e152
commit e3fe61203c
1 changed files with 25 additions and 30 deletions

View File

@ -2583,7 +2583,7 @@ struct llama_hparams {
return n_embd_head_v * n_head_kv;
}
uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings
uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings
// TODO: support using an SSM in place of the MLP of a Transformer
if (n_head_kv(il) != 0) { return 0; }
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
@ -2597,7 +2597,7 @@ struct llama_hparams {
}
}
uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings
uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings
// TODO: support using an SSM in place of the MLP of a Transformer
if (n_head_kv(il) != 0) { return 0; }
@ -2875,17 +2875,13 @@ struct llama_kv_self_cache {
struct llama_rs_cell {
llama_pos pos = -1;
int32_t src = -1; // copy source id (cleared next when -1)
int32_t src = -1; // copy source id (cleared next when -1)
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool has_seq_id(const llama_seq_id & id) const { return seq_id.find(id) != seq_id.end(); }
bool is_empty() const {
return seq_id.empty();
}
bool is_empty() const { return seq_id.empty(); }
};
struct llama_rs_seq_meta {
@ -2895,24 +2891,23 @@ struct llama_rs_seq_meta {
// ring-buffered tree of cached recurrent state data
struct llama_rs_self_cache {
uint32_t head = 0; // first state used for the last slot
uint32_t head = 0; // first state used for the last slot
uint32_t size = 0;
uint32_t used = 0;
// computed when finding a slot
uint32_t n = 0; // range of states used for the last slot
uint32_t n = 0; // range of states used for the last slot
// with state models, a cell can hold the state for more than one past token
std::vector<llama_rs_cell> cells;
// find tail cells faster
std::vector<llama_rs_seq_meta> seq_tails; // map seq_ids to cell ids
std::vector<llama_rs_seq_meta> seq_tails; // map seq_ids to cell ids
// per layer
// NOTE: the naming of r and s is arbitrary
std::vector<struct ggml_tensor *> r_l; // rolling/shift states
std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states
std::vector<struct ggml_tensor *> r_l; // rolling/shift states
std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states
// Inefficient, but thorough verification and rebuilding of the rs cache
// from only the cells list with `pos` and seq_ids.
@ -2920,21 +2915,21 @@ struct llama_rs_self_cache {
bool rebuild(bool debug) {
bool was_valid = true;
// skip for non-recurrent models
if (size == 0) { return true; }
if (size == 0) {
return true;
}
// the source of truth is the cells list
// buffer sizes
if (size != cells.size()) {
if (debug) {
LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n",
__func__, cells.size(), size);
LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", __func__, cells.size(), size);
}
cells.resize(size);
was_valid = false;
}
if (size != seq_tails.size()) {
if (debug) {
LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n",
__func__, seq_tails.size(), size);
LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", __func__, seq_tails.size(), size);
}
seq_tails.resize(size);
was_valid = false;
@ -2994,7 +2989,7 @@ struct llama_rs_self_cache {
for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
llama_rs_cell & cell = cells[cell_id];
if (cell.has_seq_id(seq_id)) {
seq_cells.push_back({cell.pos, cell_id});
seq_cells.push_back({ cell.pos, cell_id });
}
}
// sort by pos and then by cell_id
@ -3718,16 +3713,16 @@ static bool llama_kv_cache_init(
}
if (has_kv) {
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size);
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i) * kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i) * kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.kv.k_l.push_back(k);
cache.kv.v_l.push_back(v);
}
if (has_rs) {
ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size);
ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size);
ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i) * rs_size);
ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i) * rs_size);
ggml_format_name(r, "cache_r_l%d", i);
ggml_format_name(s, "cache_s_l%d", i);
cache.rs.r_l.push_back(r);
@ -4370,8 +4365,8 @@ struct llama_kv_slot_restorer {
bool do_restore = false;
explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
old_state.head = cache.kv.head;
old_state.n = cache.kv.n;
old_state.head = cache.kv.head;
old_state.n = cache.kv.n;
}
// saves a slot information for future restoration
@ -4388,10 +4383,10 @@ struct llama_kv_slot_restorer {
// and rollback changes from all llama_kv_cache_find_slot calls
void restore(struct llama_kv_cache & cache) {
if (do_restore) {
cache.kv.head = old_state.head;
cache.kv.n = old_state.n;
cache.kv.head = old_state.head;
cache.kv.n = old_state.n;
if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased
if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased
llama_kv_cache_seq_rm(cache, -1, -1, -1);
} else {
for (auto & slot : slot_boundaries) {