working on SWA with alternating local and global attention
This commit is contained in:
parent
39c029144b
commit
e101005d1a
@@ -171,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_DIMENSION_COUNT,    "%s.rope.dimension_count"    },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
     { LLM_KV_ROPE_FREQ_BASE,          "%s.rope.freq_base"          },
+    { LLM_KV_ROPE_FREQ_BASE_SWA,      "%s.rope.freq_base_swa"      },
     { LLM_KV_ROPE_SCALE_LINEAR,       "%s.rope.scale_linear"       },
     { LLM_KV_ROPE_SCALING_TYPE,       "%s.rope.scaling.type"       },
     { LLM_KV_ROPE_SCALING_FACTOR,     "%s.rope.scaling.factor"     },
@@ -176,6 +176,7 @@ enum llm_kv {
     LLM_KV_ROPE_DIMENSION_SECTIONS,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,
+    LLM_KV_ROPE_FREQ_BASE_SWA,
     LLM_KV_ROPE_SCALING_TYPE,
     LLM_KV_ROPE_SCALING_FACTOR,
     LLM_KV_ROPE_SCALING_ATTN_FACTOR,
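For context, a minimal sketch (not part of this commit) of reading the new key from a GGUF file with the standalone gguf API, falling back to the regular base when the SWA key is absent; the file name and the `bert.` key prefix are placeholders:

    #include "gguf.h" // standalone GGUF reader API (older trees declare this in ggml.h)
    #include <cstdio>

    int main() {
        struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ nullptr };
        struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
        if (!gctx) {
            return 1;
        }

        float freq_base     = 10000.0f;
        float freq_base_swa = 10000.0f;

        // "bert." is a placeholder prefix; the real one comes from general.architecture
        int64_t k = gguf_find_key(gctx, "bert.rope.freq_base");
        if (k >= 0) {
            freq_base = gguf_get_val_f32(gctx, k);
        }

        k = gguf_find_key(gctx, "bert.rope.freq_base_swa");
        freq_base_swa = k >= 0 ? gguf_get_val_f32(gctx, k) : freq_base;

        printf("freq_base = %.1f, freq_base_swa = %.1f\n", freq_base, freq_base_swa);
        gguf_free(gctx);
        return 0;
    }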
@@ -7559,6 +7559,7 @@ struct llm_build_modern_bert : public llm_graph_context {
    const int64_t n_head_kv   = hparams.n_head_kv();
    const int64_t n_embd_head = hparams.n_embd_head_v;
    const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+   const int64_t n_local_swa = hparams.n_swa; // sliding-window width for local layers
    const int64_t n_tokens    = ubatch.n_tokens;

    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
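For context, `n_swa` is the sliding-window width from the model hyperparameters. A minimal sketch (not part of this commit) of the symmetric window rule a bidirectional encoder like this would apply per query/key pair; the half-window convention is an assumption:

    #include <cstdint>

    // editorial sketch: key position p_k is visible to query position p_q
    // when it falls inside a symmetric window of width n_swa
    static bool swa_allowed(int64_t p_q, int64_t p_k, int64_t n_swa) {
        const int64_t d = p_q > p_k ? p_q - p_k : p_k - p_q;
        return d <= n_swa / 2;
    }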
@@ -7574,7 +7575,17 @@ struct llm_build_modern_bert : public llm_graph_context {
    const float beta_fast = 0.0f;
    const float beta_slow = 0.0f;

    ggml_tensor * inp_pos = build_inp_pos();

+   // graph input holding positions for up to 4096 tokens; used directly for
+   // global layers and gathered per-window for local ones
+   ggml_tensor * inp_pos_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 4096, 1);
+   ggml_set_input(inp_pos_global);
+   size_t element_size = ggml_type_size(inp_pos_global->type);
+
+   size_t nb1 = element_size;
+   size_t nb2 = nb1;
+
+   inp_pos_global = ggml_view_3d(ctx0, inp_pos_global, 1, 1, 4096, nb1, nb2, 0);
+   inp_pos_global = ggml_cont(ctx0, inp_pos_global);
+
    ggml_tensor * inpL = build_inp_embd(model.tok_embd);

    if (model.type_embd) {
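Since the tensor is freshly allocated and contiguous, the strides nb1 = nb2 = sizeof(int32_t) make the view above equivalent to a plain reshape to 1 x 1 x 4096. A minimal standalone check (not part of this commit), assuming the ggml headers:

    #include "ggml.h"
    #include <cassert>

    int main() {
        struct ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        // same construction as the graph code: 4096 positions, I32
        ggml_tensor * pos = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 4096, 1);

        const size_t nb1 = ggml_type_size(pos->type); // 4 bytes per element
        ggml_tensor * v = ggml_view_3d(ctx, pos, 1, 1, 4096, nb1, nb1, 0);
        ggml_tensor * r = ggml_reshape_3d(ctx, pos, 1, 1, 4096);

        // identical shapes: element (0,0,k) of the view sits at byte offset k*nb2
        assert(v->ne[0] == r->ne[0] && v->ne[1] == r->ne[1] && v->ne[2] == r->ne[2]);

        ggml_free(ctx);
        return 0;
    }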
@@ -7582,19 +7593,20 @@ struct llm_build_modern_bert : public llm_graph_context {
    }
    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);

-   auto * inp_attn = build_attn_inp_no_cache();
+   auto * inp_attn = build_attn_inp_kv_unified_iswa();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * x = inpL;

-       // Pre attention Layer norm
+       // pre-attention LayerNorm
        ggml_tensor * x_attn_in = x;
        if (model.layers[il].attn_norm) {
            x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
        }

-       // fused qkv
+       // fused QKV projection
        GGML_ASSERT(model.layers[il].wqkv);
        ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
        if (model.layers[il].bqkv) {
@@ -7609,41 +7621,120 @@ struct llm_build_modern_bert : public llm_graph_context {
        if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
        if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);

-       // reshape for multi head
+       // reshape for multi-head attention
        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

-       // rope embedding
-       Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
-               n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-               ext_factor, attn_factor, beta_fast, beta_slow);
-       Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
-               n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-               ext_factor, attn_factor, beta_fast, beta_slow);
+       // global or local layer
+       bool is_global = ((il + 1) % 3 == 0); // every third layer attends globally
+       float freq_base_l  = is_global ? 160000.0f : 10000.0f; // rope theta per layer type
+       float freq_scale_l = 1.0f;
+
+       ggml_tensor * pos_q = inp_pos_global;
+
+       ggml_tensor * K_work = Kcur;
+       ggml_tensor * V_work = Vcur;
+       ggml_tensor * pos_k  = inp_pos_global;
+
+       if (!is_global) {
+           // indices of the K/V rows that fall inside this layer's sliding window
+           ggml_tensor * idx_src = inp_attn->self_k_idxs_swa;
+
+           ggml_tensor * idx_view1d = ggml_view_1d(ctx0, idx_src, idx_src->ne[0], 0);
+           ggml_tensor * idx_cont   = ggml_cont(ctx0, idx_view1d);
+
+           ggml_tensor * idx_i32 = idx_cont;
+           if (idx_i32->type != GGML_TYPE_I32) {
+               idx_i32 = ggml_cast(ctx0, idx_cont, GGML_TYPE_I32);
+           }
+
+           const int64_t n_indices = idx_i32->ne[0];
+           ggml_tensor * idx_2d = ggml_view_2d(ctx0, idx_i32, 1, n_indices, sizeof(int32_t), 0);
+
+           idx_2d = ggml_cont(ctx0, idx_2d);
+           if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32);
+
+           LLAMA_LOG_INFO("swa gather: Kcur ne = [%lld, %lld, %lld], idx_2d ne = [%lld, %lld, %lld, %lld], type = %d\n",
+                   (long long) Kcur->ne[0],   (long long) Kcur->ne[1],   (long long) Kcur->ne[2],
+                   (long long) idx_2d->ne[0], (long long) idx_2d->ne[1], (long long) idx_2d->ne[2], (long long) idx_2d->ne[3],
+                   idx_2d->type);
+           // gather only the in-window K/V rows and their positions
+           K_work = ggml_get_rows(ctx0, Kcur, idx_2d);
+           V_work = ggml_get_rows(ctx0, Vcur, idx_2d);
+
+           ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d);
+
+           // flatten to a vector if needed, then make contiguous
+           if (!ggml_is_vector(pos_rows)) {
+               const int64_t n_el = ggml_nelements(pos_rows);
+               pos_rows = ggml_view_1d(ctx0, pos_rows, n_el, 0);
+           }
+           pos_rows = ggml_cont(ctx0, pos_rows);
+
+           // ensure I32
+           if (pos_rows->type != GGML_TYPE_I32) {
+               pos_rows = ggml_cast(ctx0, pos_rows, GGML_TYPE_I32);
+           }
+
+           // final pos_k to pass to RoPE
+           pos_k = pos_rows;
+           LLAMA_LOG_INFO("pos_k final: ne[0] = %lld, type = %d\n", (long long) pos_k->ne[0], pos_k->type);
+       }
+       // flatten pos_q to a contiguous vector if needed
+       if (!ggml_is_vector(pos_q)) {
+           const int64_t n_el = ggml_nelements(pos_q);
+           pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0);
+           pos_q = ggml_cont(ctx0, pos_q);
+       }
+
+       // apply RoPE with the per-layer frequency base
+       Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr,
+               n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+               ext_factor, attn_factor, beta_fast, beta_slow);
+
+       if (!ggml_is_vector(pos_k)) {
+           const int64_t n_el = ggml_nelements(pos_k);
+           pos_k = ggml_view_1d(ctx0, pos_k, n_el, 0);
+           pos_k = ggml_cont(ctx0, pos_k);
+       }
+
+       K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr,
+               n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+               ext_factor, attn_factor, beta_fast, beta_slow);
+
+       // choose mask: global vs SWA
+       ggml_tensor * kq_b_layer = is_global ? inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa;
+
        ggml_tensor * attn_out = build_attn(
                inp_attn,
-               model.layers[il].wo, model.layers[il].bo,
-               Qcur, Kcur, Vcur,
-               /*k cache*/ nullptr,
-               /*v cache*/ nullptr,
+               model.layers[il].wo,
+               model.layers[il].bo,
+               Qcur,
+               K_work,
+               V_work,
+               kq_b_layer,
                nullptr,
                1.0f / sqrtf(float(n_embd_head)),
                il
        );
        // residual addition
        ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);

-       // optional subselect output tokens (inp_out_ids)
+       // optionally subselect the output tokens (inp_out_ids)
        if (il == n_layer - 1 && inp_out_ids) {
            cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
            inpL     = ggml_get_rows(ctx0, inpL, inp_out_ids);
        }

-       // pre mlp LayerNorm
+       // pre-MLP LayerNorm
        ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);

-       // geglu FFN
+       // GeGLU FFN
        ggml_tensor * mlp_out = build_ffn(
                h,
                model.layers[il].ffn_up, NULL, NULL,
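For context, the schedule implied by `is_global = ((il + 1) % 3 == 0)` above makes every third layer global (il = 2, 5, 8, ...), with RoPE theta 160000.0 on global layers and 10000.0 on local ones. A toy print-out (not part of this commit; n_layer = 22 is an illustrative assumption):

    #include <cstdio>

    int main() {
        const int n_layer = 22; // assumption for illustration
        for (int il = 0; il < n_layer; ++il) {
            const bool is_global = ((il + 1) % 3 == 0); // same predicate as the graph code
            printf("layer %2d: %s (theta = %.0f)\n", il,
                   is_global ? "global" : "local ",
                   is_global ? 160000.0 : 10000.0);
        }
        return 0;
    }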
@@ -7653,10 +7744,11 @@ struct llm_build_modern_bert : public llm_graph_context {
            LLM_FFN_GEGLU, LLM_FFN_PAR, il
        );

-       // resid addition
+       // residual addition after the FFN
        inpL = ggml_add(ctx0, mlp_out, cur_attn);
    }

    ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
    res->t_embd = cur;
    ggml_build_forward_expand(gf, cur);
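The local branch above leans on `ggml_get_rows` with an I32 index tensor to gather the windowed K/V rows and their positions. A self-contained sketch of that primitive (not part of this commit; assumes a CPU build, where `ggml_graph_compute_with_ctx` may live in `ggml-cpu.h` on newer trees):

    #include "ggml.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        // 4 rows of width 2; gather rows 3 and 1, as the SWA branch does for K/V
        ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 4);
        ggml_tensor * idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);

        float   * s = (float   *) src->data;
        int32_t * d = (int32_t *) idx->data;
        for (int i = 0; i < 8; ++i) {
            s[i] = (float) i;
        }
        d[0] = 3;
        d[1] = 1;

        ggml_tensor * out = ggml_get_rows(ctx, src, idx);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        const float * o = (const float *) out->data;
        printf("row 3 -> {%.0f, %.0f}, row 1 -> {%.0f, %.0f}\n", o[0], o[1], o[2], o[3]);

        ggml_free(ctx);
        return 0;
    }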