working on SWA with alternating local and global attention
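
Every third layer runs full (global) attention with a larger RoPE base; the remaining layers run sliding-window (local) attention against the SWA mask with the default base. A minimal sketch of the per-layer selection, assuming the same 1-in-3 split and the 160000/10000 theta values hard-coded in the diff below:

// sketch (not part of the diff): how a layer picks global vs local (SWA) attention,
// mirroring the hard-coded pattern in llm_build_modern_bert below
struct layer_attn_cfg {
    bool  is_global;  // full attention vs sliding-window attention
    float freq_base;  // RoPE theta used for this layer
};

static layer_attn_cfg select_layer_attn(int il) {
    const bool is_global = ((il + 1) % 3 == 0);              // every 3rd layer is global
    return { is_global, is_global ? 160000.0f : 10000.0f };  // global theta vs local theta
}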

This commit is contained in:
ryan-mangeno 2025-09-07 21:00:38 -04:00
parent 39c029144b
commit e101005d1a
3 changed files with 113 additions and 19 deletions

View File

@ -171,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },

View File

@ -176,6 +176,7 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_FREQ_BASE_SWA,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_ATTN_FACTOR,

View File

@ -7559,6 +7559,7 @@ struct llm_build_modern_bert : public llm_graph_context {
const int64_t n_head_kv = hparams.n_head_kv();
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
const int64_t n_local_swa = hparams.n_swa;
const int64_t n_tokens = ubatch.n_tokens;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@ -7574,7 +7575,17 @@ struct llm_build_modern_bert : public llm_graph_context {
const float beta_fast = 0.0f;
const float beta_slow = 0.0f;
ggml_tensor * inp_pos = build_inp_pos();
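// global position ids: fixed-size I32 input (hard-coded to 4096), viewed and made contiguous so it can be flattened for RoPE and gathered with ggml_get_rows below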
ggml_tensor * inp_pos_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 4096, 1);
ggml_set_input(inp_pos_global);
size_t element_size = ggml_type_size(inp_pos_global->type);
size_t nb1 = element_size;
size_t nb2 = nb1;
inp_pos_global = ggml_view_3d(ctx0, inp_pos_global, 1, 1, 4096, nb1, nb2, 0);
inp_pos_global = ggml_cont(ctx0, inp_pos_global);
ggml_tensor * inpL = build_inp_embd(model.tok_embd);
if (model.type_embd) {
@ -7582,19 +7593,20 @@ struct llm_build_modern_bert : public llm_graph_context {
}
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
auto * inp_attn = build_attn_inp_no_cache();
auto * inp_attn = build_attn_inp_kv_unified_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * x = inpL;
// Pre attention Layer norm
// pre attn LayerNorm
ggml_tensor * x_attn_in = x;
if (model.layers[il].attn_norm) {
x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
}
// fused qkv
// fused QKV
GGML_ASSERT(model.layers[il].wqkv);
ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
if (model.layers[il].bqkv) {
@ -7609,41 +7621,120 @@ struct llm_build_modern_bert : public llm_graph_context {
if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
// reshape for multi head
// reshape for multi-head attention
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// rope embedding
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
// global or local layer
bool is_global = ((il + 1) % 3 == 0);
float freq_base_l = is_global ? 160000.0f : 10000.0f; // rope theta
float freq_scale_l = 1.0f;
ggml_tensor * pos_q = inp_pos_global;
ggml_tensor * K_work = Kcur;
ggml_tensor * V_work = Vcur;
ggml_tensor * pos_k = inp_pos_global;
if (!is_global) {
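// local (SWA) layer: build an I32 index from the SWA KV indices, then gather the corresponding K/V rows and positions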
ggml_tensor * idx_src = inp_attn->self_k_idxs_swa;
ggml_tensor * idx_view1d = ggml_view_1d(ctx0, idx_src, idx_src->ne[0], 0);
ggml_tensor * idx_cont = ggml_cont(ctx0, idx_view1d);
ggml_tensor * idx_i32 = idx_cont;
if (idx_i32->type != GGML_TYPE_I32) {
idx_i32 = ggml_cast(ctx0, idx_cont, GGML_TYPE_I32);
}
const int64_t n_indices = idx_i32->ne[0];
ggml_tensor * idx_2d = ggml_view_2d(ctx0, idx_i32, 1, n_indices, sizeof(int32_t), 0);
idx_2d = ggml_cont(ctx0, idx_2d);
if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32);
// debug: log K and gathered-index shapes
LLAMA_LOG_INFO("Kcur: ne=[%lld, %lld, %lld], idx_2d: ne=[%lld, %lld, %lld, %lld], type=%d\n",
Kcur->ne[0], Kcur->ne[1], Kcur->ne[2],
idx_2d->ne[0], idx_2d->ne[1], idx_2d->ne[2], idx_2d->ne[3],
idx_2d->type);
K_work = ggml_get_rows(ctx0, Kcur, idx_2d);
V_work = ggml_get_rows(ctx0, Vcur, idx_2d);
ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d);
if (!ggml_is_vector(pos_rows)) {
const int64_t n_el = ggml_nelements(pos_rows);
pos_rows = ggml_view_1d(ctx0, pos_rows, n_el, 0);
pos_rows = ggml_cont(ctx0, pos_rows);
} else {
pos_rows = ggml_cont(ctx0, pos_rows);
}
// ensure I32
if (pos_rows->type != GGML_TYPE_I32) {
pos_rows = ggml_cast(ctx0, pos_rows, GGML_TYPE_I32);
}
// final pos_k to pass to rope
pos_k = pos_rows;
LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, type=%d\n", pos_k->ne[0], pos_k->type);
}
if (!ggml_is_vector(pos_q)) {
const int64_t n_el = ggml_nelements(pos_q);
pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0);
pos_q = ggml_cont(ctx0, pos_q);
}
// apply rope
Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
if (!ggml_is_vector(pos_k)) {
const int64_t n_el = ggml_nelements(pos_k);
pos_k = ggml_view_1d(ctx0, pos_k, n_el, 0);
pos_k = ggml_cont(ctx0, pos_k);
}
K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
// choose mask: global (full) vs local (SWA)
ggml_tensor * kq_b_layer = is_global ? inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa;
ggml_tensor * attn_out = build_attn(
inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur,
/*k cache*/ nullptr,
/*v cache*/ nullptr,
model.layers[il].wo,
model.layers[il].bo,
Qcur,
K_work,
V_work,
kq_b_layer,
nullptr,
1.0f / sqrtf(float(n_embd_head)),
il
);
// residual addition
ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);
// optional subselect output tokens (inp_out_ids)
// optional output select
if (il == n_layer - 1 && inp_out_ids) {
cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// pre mlp LayerNorm
// pre mlp layer norm
ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);
// geglu FFN
// geglu ffn
ggml_tensor * mlp_out = build_ffn(
h,
model.layers[il].ffn_up, NULL, NULL,
@ -7653,10 +7744,11 @@ struct llm_build_modern_bert : public llm_graph_context {
LLM_FFN_GEGLU, LLM_FFN_PAR, il
);
// resid addition
// residual addition after FFN
inpL = ggml_add(ctx0, mlp_out, cur_attn);
}
ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
res->t_embd = cur;
ggml_build_forward_expand(gf, cur);