diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 34cd49083b..6aa1426a28 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7551,209 +7551,117 @@ struct llm_build_bert : public llm_graph_context {
 };
 
 struct llm_build_modern_bert : public llm_graph_context {
-    llm_build_modern_bert(const llama_model & model, const llm_graph_params & params)
-        : llm_graph_context(params) {
-        const int64_t n_embd    = hparams.n_embd;
-        const int64_t n_layer   = hparams.n_layer;
-        const int64_t n_head    = hparams.n_head();
-        const int64_t n_head_kv = hparams.n_head_kv();
+    llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        const int64_t n_local_swa = hparams.n_swa;
-        const int64_t n_tokens    = ubatch.n_tokens;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
-        // rope params
-        const int32_t rope_type   = LLAMA_ROPE_TYPE_NEOX;
-        const int32_t n_rot       = hparams.n_rot;
-        const int32_t n_ctx_orig  = hparams.n_ctx_train;
-        const float   freq_base   = hparams.rope_freq_base_train;
-        const float   freq_scale  = hparams.rope_freq_scale_train;
-        const float   attn_factor = 1.0f;
-        const float   ext_factor  = 1.0f;
-        const float   beta_fast   = 0.0f;
-        const float   beta_slow   = 0.0f;
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inp_pos = build_inp_pos(); // Initialize inp_pos with build_inp_pos()
+        // construct input embeddings (token, type, position)
+        inpL = build_inp_embd(model.tok_embd);
+        cb(inpL, "inp_embd", -1);
 
-        ggml_tensor *inp_pos_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 4096, 1);
-        ggml_set_input(inp_pos_global);
-        size_t element_size = ggml_type_size(inp_pos_global->type);
-
-        size_t nb1 = element_size;
-        size_t nb2 = nb1;
-
-        inp_pos_global = ggml_view_3d(ctx0, inp_pos_global, 1, 1, 4096, nb1, nb2, 0);
-        inp_pos_global = ggml_cont(ctx0, inp_pos_global);
-
-        ggml_tensor * inpL = build_inp_embd(model.tok_embd);
-
-        if (model.type_embd) {
-            inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0));
-        }
-        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+        // embed layer norm
+        inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
+        cb(inpL, "inp_norm", -1);
 
         auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
-        ggml_tensor * inp_out_ids = build_inp_out_ids();
-
+        // iterate layers
         for (int il = 0; il < n_layer; ++il) {
-            LLAMA_LOG_INFO("Setting layer %d\n", il);
-            ggml_tensor * x = inpL;
+            ggml_tensor * cur = inpL;
 
-            // pre attn LayerNorm
-            ggml_tensor * x_attn_in = x;
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            float rope_theta = il % 3 == 0 ? hparams.rope_freq_base_train : hparams.rope_freq_base_train_swa;
+
+            // attention layer norm
             if (model.layers[il].attn_norm) {
-                x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM, il);
+                cb(cur, "attn_norm", il);
             }
 
-            // fused QKV
-            GGML_ASSERT(model.layers[il].wqkv);
-            ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
-            if (model.layers[il].bqkv) {
-                qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv);
-            }
+            // self attention
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
 
-            ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0));
-            ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], n_embd));
-            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], n_embd + n_embd_gqa));
+            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-            // optional q/k LayerNorm
-            if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
-            if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
-
-            // reshape for multi-head attention
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            // global or local layer
-            bool is_global = ((il + 1) % 3 == 0);
-            float freq_base_l = is_global ? 160000.0f : 10000.0f; // rope theta
-            float freq_scale_l = 1.0f;
+            // RoPE
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
 
-            ggml_tensor * pos_q = inp_pos_global;
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
 
-            ggml_tensor * K_work = Kcur;
-            ggml_tensor * V_work = Vcur;
-            ggml_tensor * pos_k  = inp_pos_global;
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
 
-            if (!is_global) {
-                ggml_tensor * idx_src = inp_attn->self_k_idxs_swa;
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
 
-                ggml_tensor * idx_view1d = ggml_view_1d(ctx0, idx_src, idx_src->ne[0], 0);
-                ggml_tensor * idx_cont   = ggml_cont(ctx0, idx_view1d);
-
-                ggml_tensor * idx_i32 = idx_cont;
-                if (idx_i32->type != GGML_TYPE_I32) {
-                    idx_i32 = ggml_cast(ctx0, idx_cont, GGML_TYPE_I32);
-                }
-
-                const int64_t n_indices = idx_i32->ne[0];
-                ggml_tensor * idx_2d = ggml_view_2d(ctx0, idx_i32, 1, n_indices, sizeof(int32_t), 0);
-
-                idx_2d = ggml_cont(ctx0, idx_2d);
-                if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32);
-
-                K_work = ggml_get_rows(ctx0, Kcur, idx_2d);
-                V_work = ggml_get_rows(ctx0, Vcur, idx_2d);
-
-                ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d);
-
-                if (!ggml_is_vector(pos_rows)) {
-                    const int64_t n_el = ggml_nelements(pos_rows);
-                    pos_rows = ggml_view_1d(ctx0, pos_rows, n_el, 0);
-                    pos_rows = ggml_cont(ctx0, pos_rows);
-                } else {
-                    pos_rows = ggml_cont(ctx0, pos_rows);
-                }
-                // ensure I32
-                if (pos_rows->type != GGML_TYPE_I32) {
-                    pos_rows = ggml_cast(ctx0, pos_rows, GGML_TYPE_I32);
-                }
-
-                // final pos_k to pass to rope
-                pos_k = pos_rows;
-            }
-
-            if( !ggml_is_vector(pos_q) ) {
-                const int64_t n_el = ggml_nelements(pos_q);
-                pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0);
-                pos_q = ggml_cont(ctx0, pos_q);
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
 
-            // apply rope
-            Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
-                    ext_factor, attn_factor, beta_fast, beta_slow);
+            // re-add the layer input
+            cur = ggml_add(ctx0, cur, inpL);
 
-            if( !ggml_is_vector(pos_k) ) {
-                const int64_t n_el = ggml_nelements(pos_k);
-                pos_k = ggml_view_1d(ctx0, pos_k, n_el, 0);
-                pos_k = ggml_cont(ctx0, pos_k);
-            }
+            // attention layer norm
+            cur = build_norm(cur, model.layers[il].attn_out_norm, nullptr, LLM_NORM, il);
 
-            K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
-                    ext_factor, attn_factor, beta_fast, beta_slow);
+            ggml_tensor * ffn_inp = cur;
+            cb(ffn_inp, "ffn_inp", il);
 
-            // choseing mask, global vs swa
-            ggml_tensor * kq_mask = is_global ? inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa;
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,
+                    NULL, NULL, NULL, NULL, NULL,
+                    model.layers[il].ffn_down,
+                    NULL, NULL, NULL,
+                    LLM_FFN_GEGLU, LLM_FFN_SEQ, il);
 
-            // flatten K/V back to full embedding dim
-            int64_t n_embd   = n_embd_head * n_head_kv;
-            int64_t n_tokens = Kcur->ne[2];
+            // attentions bypass the intermediate layer
+            cur = ggml_add(ctx0, cur, ffn_inp);
 
-            ggml_tensor *K_2d = ggml_reshape_2d(ctx0, Kcur, n_embd, n_tokens);
-
-            ggml_tensor *K_flat = ggml_view_3d(ctx0, K_2d, n_embd, 1, n_tokens,
-                    K_2d->nb[0], K_2d->nb[1], 0);
-            K_flat = ggml_cont(ctx0, K_flat);
-            ggml_tensor * V_flat = ggml_reshape_2d(ctx0, Vcur, n_embd, n_tokens);
-
-
-            ggml_tensor * attn_out = build_attn(
-                inp_attn,
-                model.layers[il].wo,
-                model.layers[il].bo,
-                Qcur,
-                K_flat,
-                V_flat,
-                kq_mask,
-                nullptr,
-                1.0f / sqrtf(float(n_embd_head)),
-                il
-            );
-
-
-            ggml_tensor * cur_attn = ggml_add(ctx0, x, attn_out);
-
-            // optional output select
-            if (il == n_layer - 1 && inp_out_ids) {
-                cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // pre mlp layer norm
-            ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);
-
-            // geglu ffn
-            ggml_tensor * mlp_out = build_ffn(
-                h,
-                model.layers[il].ffn_up, NULL, NULL,
-                NULL, NULL, NULL,
-                model.layers[il].ffn_down, NULL, NULL,
-                NULL,
-                LLM_FFN_GEGLU, LLM_FFN_PAR, il
-            );
-
-            // resudi addition after FFN
-            inpL = ggml_add(ctx0, mlp_out, cur_attn);
+            // input for next layer
+            inpL = cur;
         }
+        cur = inpL;
 
-        ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
+        cur = build_norm(cur,
+                model.output_norm_enc, NULL,
+                LLM_NORM, -1);
+
+        cb(cur, "result_embd", -1);
 
         res->t_embd = cur;
+        ggml_build_forward_expand(gf, cur);
     }
 };
 
@@ -18450,6 +18358,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
        case LLM_ARCH_MODERN_BERT:
            {
                llm = std::make_unique<llm_build_modern_bert>(*this, params);
+                LLAMA_LOG_INFO("Built llm\n");
            } break;
        case LLM_ARCH_NEO_BERT:
            {
@@ -18768,6 +18677,8 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            GGML_ABORT("fatal error");
    }
 
+    LLAMA_LOG_INFO("Building pooling\n");
+
    // add on pooling layer
    llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
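
Note on the per-layer RoPE base chosen above: the new builder selects rope_theta with `il % 3 == 0 ? hparams.rope_freq_base_train : hparams.rope_freq_base_train_swa`, treating every third layer as a global-attention layer and the rest as sliding-window layers (the removed code hard-coded 160000.0f and 10000.0f for the same split). The standalone sketch below only illustrates that selection pattern; the layer count and the two base values are illustrative assumptions, not values taken from the patch.

// Sketch only: global-vs-sliding-window layer pattern used to pick rope_theta.
// n_layer and the two frequency bases are assumed for demonstration purposes.
#include <cstdio>

int main() {
    const int   n_layer          = 22;        // hypothetical layer count
    const float freq_base_global = 160000.0f; // stands in for hparams.rope_freq_base_train
    const float freq_base_swa    = 10000.0f;  // stands in for hparams.rope_freq_base_train_swa

    for (int il = 0; il < n_layer; ++il) {
        const bool  is_global  = (il % 3 == 0); // same test as in the builder
        const float rope_theta = is_global ? freq_base_global : freq_base_swa;
        std::printf("layer %2d: %s, rope_theta = %.1f\n",
                il, is_global ? "global" : "sliding-window", rope_theta);
    }
    return 0;
}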