diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 04b76f92cf..62e68a912d 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -263,8 +263,7 @@ static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_s
     const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
                                 (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
                                 (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
-                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" :
-                                (swa_type == LLAMA_SWA_TYPE_LOCAL) ? "LLAMA_SWA_TYPE_LOCAL" : "unknown";
+                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 81325a9ccf..116d728e8c 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -20,7 +20,6 @@ enum llama_swa_type {
     LLAMA_SWA_TYPE_STANDARD  = 1,
     LLAMA_SWA_TYPE_CHUNKED   = 2,
     LLAMA_SWA_TYPE_SYMMETRIC = 3,
-    LLAMA_SWA_TYPE_LOCAL     = 4,
 };
 
 struct llama_hparams_posnet {
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index c6ad8d2ac5..726e614a29 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -786,7 +786,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_MODERN_BERT:
             {
-                hparams.swa_type = LLAMA_SWA_TYPE_LOCAL;
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.set_swa_pattern(3, 0);
 
                 hparams.rope_freq_base_train_swa = 10000.f;
@@ -7845,15 +7845,16 @@ template
 struct llm_build_modern_bert : public llm_graph_context {
     llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
         const float rope_theta_global = hparams.rope_freq_base_train;
         const float rope_theta_local = hparams.rope_freq_base_train_swa;
+        //const uint32_t n_swa_local = hparams.n_swa;
+        //const uint32_t n_swa_global = 4096;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
-        ggml_tensor * cur;
-        ggml_tensor * inpL;
-        ggml_tensor * inp_pos = build_inp_pos(); // Initialize inp_pos with build_inp_pos()
+        ggml_tensor * cur = nullptr;
+        ggml_tensor * inpL = nullptr;
+        ggml_tensor * inp_pos = build_inp_pos();
 
         // construct input embeddings (token, type, position)
         inpL = build_inp_embd(model.tok_embd);
@@ -7869,11 +7870,12 @@ struct llm_build_modern_bert : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * cur = inpL;
 
-            ggml_tensor * Qcur;
-            ggml_tensor * Kcur;
-            ggml_tensor * Vcur;
+            ggml_tensor * Qcur = nullptr;
+            ggml_tensor * Kcur = nullptr;
+            ggml_tensor * Vcur = nullptr;
+
+            const float rope_theta = (il+1) % 3 == 0 ? rope_theta_global : rope_theta_local;
 
-            const float rope_theta = il % 3 == 0 ? rope_theta_global : rope_theta_local;
 
             // attention layer norm
             if (model.layers[il].attn_norm) {
@@ -7887,9 +7889,9 @@ struct llm_build_modern_bert : public llm_graph_context {
                 cur = build_lora_mm(model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd*2)));
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index e1e0d66fbb..39fa446f56 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1864,8 +1864,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "jina-v2-es" ||
                 tokenizer_pre == "jina-v2-de" ||
                 tokenizer_pre == "a.x-4.0" ||
-                tokenizer_pre == "mellum" ||
-                tokenizer_pre == "modern-bert") {
+                tokenizer_pre == "mellum" ||
+                tokenizer_pre == "modern-bert" ) {
             pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
         } else if (
             tokenizer_pre == "jina-v1-en" ||
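
A note on the `rope_theta` change in the layer loop: assuming `set_swa_pattern(3, 0)` marks every third layer (those with `il % 3 == 2`) as full-attention and the rest as sliding-window, the old test `il % 3 == 0` applied the global RoPE base to layers 0, 3, 6, …, which are SWA layers under that pattern, while `(il+1) % 3 == 0` moves it to layers 2, 5, 8, … so the global base lines up with the full-attention layers. The standalone sketch below (not part of the patch; the 9-layer count and the global theta value are placeholders) prints both mappings side by side:

```cpp
// Standalone illustration of the layer -> rope_theta mapping before and after
// this patch. Layer count and rope_theta_global are placeholder values.
#include <cstdio>

int main() {
    const float rope_theta_global = 160000.0f; // placeholder global base
    const float rope_theta_local  =  10000.0f; // matches rope_freq_base_train_swa above

    for (int il = 0; il < 9; ++il) {
        const bool  is_swa    = (il % 3) != 2; // assumed meaning of set_swa_pattern(3, 0)
        const float theta_old = (il % 3 == 0)       ? rope_theta_global : rope_theta_local;
        const float theta_new = ((il + 1) % 3 == 0) ? rope_theta_global : rope_theta_local;
        std::printf("layer %d: swa=%d old_theta=%6.0f new_theta=%6.0f\n",
                    il, is_swa ? 1 : 0, theta_old, theta_new);
    }
    return 0;
}
```

Running this shows the old test handing the global base to layers 0, 3, 6 (all SWA under the assumed pattern) and the new test to layers 2, 5, 8 (the non-SWA ones).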
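
On the `ggml_view_2d` changes: dropping `n_embd_gqa` in favor of `n_embd` is a no-op when the model has no GQA (`n_head_kv == n_head`, so `n_embd_v_gqa() == n_embd`), and writing the V offset as `n_embd*2` removes the dependence on the deleted `n_embd_gqa` variable. A minimal sketch of the byte-offset arithmetic, with a hypothetical hidden size:

```cpp
// Byte offsets into one row of the fused wqkv output, assumed to be laid out
// as [Q | K | V] with each slice n_embd floats wide (MHA, no GQA).
#include <cstdio>
#include <cstddef>

int main() {
    const std::size_t n_embd = 768; // hypothetical hidden size

    const std::size_t q_off = 0 * sizeof(float) * n_embd; // Q starts the row
    const std::size_t k_off = 1 * sizeof(float) * n_embd; // K follows Q
    const std::size_t v_off = sizeof(float) * n_embd * 2; // V follows K (the rewritten offset)

    std::printf("Q @ %zu, K @ %zu, V @ %zu bytes\n", q_off, k_off, v_off);
    return 0;
}
```

The old V offset, `1*sizeof(float)*(n_embd + n_embd_gqa)`, coincides with `sizeof(float)*(n_embd*2)` exactly when `n_embd_gqa == n_embd`, which is why the change is safe here.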