alternating RoPE implemented and ModernBERT graph build succeeds

ryan-mangeno 2025-09-11 16:37:18 -04:00
parent e296a0b6e6
commit 2bacfb0bc2
1 changed file with 78 additions and 167 deletions
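What "alternating RoPE" means here: ModernBERT interleaves one global-attention layer with two sliding-window layers, and each kind gets its own RoPE base frequency. Below is a minimal standalone sketch of that selection rule, assuming the usual ModernBERT defaults (a global layer every third layer, global theta 160000, local theta 10000); the names and constants are illustrative, not llama.cpp API:

#include <cstdio>

int main() {
    const int   n_layer      = 12;
    const float theta_global = 160000.0f; // assumed default, maps to rope_freq_base_train
    const float theta_local  =  10000.0f; // assumed default, maps to rope_freq_base_train_swa

    for (int il = 0; il < n_layer; ++il) {
        // every third layer (il % 3 == 0) attends globally; the rest use a sliding window
        const bool  is_global  = (il % 3 == 0);
        const float rope_theta = is_global ? theta_global : theta_local;
        std::printf("layer %2d: %-6s rope_theta = %.0f\n",
                    il, is_global ? "global" : "local", rope_theta);
    }
    return 0;
}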

@@ -7551,209 +7551,117 @@ struct llm_build_bert : public llm_graph_context {
};
struct llm_build_modern_bert : public llm_graph_context {
llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
ggml_tensor * inp_pos = build_inp_pos();
// construct input embeddings (token, type, position)
inpL = build_inp_embd(model.tok_embd);
cb(inpL, "inp_embd", -1);
// embed layer norm
inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
cb(inpL, "inp_norm", -1);
auto * inp_attn = build_attn_inp_kv_unified_iswa();
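// the unified iswa KV input carries both the global KQ mask and the sliding-window one; build_attn selects the right mask per layer, so no explicit selection is needed here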
// iterate layers
for (int il = 0; il < n_layer; ++il) {
LLAMA_LOG_INFO("Setting layer %d\n", il);
ggml_tensor * cur = inpL;
ggml_tensor * Qcur;
ggml_tensor * Kcur;
ggml_tensor * Vcur;
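// alternating RoPE: layers with il % 3 == 0 attend globally and use the full rope theta (rope_freq_base_train, 160000 for ModernBERT); the rest use sliding-window attention with the local theta (rope_freq_base_train_swa, 10000)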
float rope_theta = il % 3 == 0 ? hparams.rope_freq_base_train : hparams.rope_freq_base_train_swa;
// attention layer norm
if (model.layers[il].attn_norm) {
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM, il);
cb(cur, "attn_norm", il);
}
// self attention
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
// optional q/k LayerNorm
if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
// reshape for multi-head attention
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// RoPE
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
cb(cur, "kqv_out", il);
if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// re-add the layer input
cur = ggml_add(ctx0, cur, inpL);
// attention output norm
cur = build_norm(cur, model.layers[il].attn_out_norm, nullptr, LLM_NORM, il);
K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ggml_tensor * ffn_inp = cur;
cb(ffn_inp, "ffn_inp", il);
cur = build_ffn(cur,
model.layers[il].ffn_up,
NULL, NULL, NULL, NULL, NULL,
model.layers[il].ffn_down,
NULL, NULL, NULL,
LLM_FFN_GEGLU, LLM_FFN_SEQ, il);
// attentions bypass the intermediate layer
cur = ggml_add(ctx0, cur, ffn_inp);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm_enc, NULL,
LLM_NORM, -1);
cb(cur, "result_embd", -1);
res->t_embd = cur;
ggml_build_forward_expand(gf, cur);
}
};
@@ -18450,6 +18358,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
case LLM_ARCH_MODERN_BERT:
{
llm = std::make_unique<llm_build_modern_bert>(*this, params);
LLAMA_LOG_INFO("Built llm\n");
} break;
case LLM_ARCH_NEO_BERT:
{
@@ -18768,6 +18677,8 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
GGML_ABORT("fatal error");
}
LLAMA_LOG_INFO("Building pooling\n");
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);